Skip to content

Commit fec5e8a

Browse files
authored
[ML] prefer least allocated model when a new node is added to the cluster (#77756)
When a new node is added to the cluster, we should first attempt to allocate models that have fewer allocated nodes. This way we prioritize getting allocations up and running vs. using a random selection.
1 parent ff9f2a2 commit fec5e8a

File tree

3 files changed

+190
-78
lines changed

3 files changed

+190
-78
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java

Lines changed: 91 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@
4040
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
4141

4242
import java.util.Collections;
43-
import java.util.List;
43+
import java.util.Comparator;
4444
import java.util.Locale;
4545
import java.util.Map;
4646
import java.util.Optional;
4747
import java.util.Set;
4848
import java.util.TreeMap;
49+
import java.util.function.Function;
4950
import java.util.stream.Collectors;
5051

5152
public class TrainedModelAllocationClusterService implements ClusterStateListener {
@@ -245,8 +246,7 @@ ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelD
245246
Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
246247
Map<String, String> nodeToReason = new TreeMap<>();
247248
for (DiscoveryNode node : currentState.getNodes().getAllNodes()) {
248-
if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
249-
&& shuttingDownNodes.contains(node.getId()) == false) {
249+
if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node) && shuttingDownNodes.contains(node.getId()) == false) {
250250
Optional<String> maybeError = nodeHasCapacity(currentState, params, node);
251251
if (maybeError.isPresent()) {
252252
nodeToReason.put(node.getName(), maybeError.get());
@@ -289,16 +289,8 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
289289
logger.trace(
290290
() -> new ParameterizedMessage("[{}] [{}] current metadata before update {}", modelId, nodeId, Strings.toString(metadata))
291291
);
292-
Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
293-
List<DiscoveryNode> allocatableNodes = currentState.nodes()
294-
.getAllNodes()
295-
.stream()
296-
.filter(
297-
d -> StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(d) && shuttingDownNodes.contains(d.getId()) == false
298-
)
299-
.collect(Collectors.toList());
300292
final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
301-
final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
293+
final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
302294
// If state is stopped, this indicates the node process is closed, remove the node from the allocation
303295
if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) {
304296
if (existingAllocation == null || existingAllocation.isRoutedToNode(nodeId) == false) {
@@ -313,20 +305,20 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
313305
}
314306
// If we are stopping, don't update anything
315307
if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
316-
logger.debug(() -> new ParameterizedMessage(
317-
"[{}] requested update from node [{}] to update route state to [{}]",
318-
modelId,
319-
nodeId,
320-
request.getRoutingState()
321-
));
308+
logger.debug(
309+
() -> new ParameterizedMessage(
310+
"[{}] requested update from node [{}] to update route state to [{}]",
311+
modelId,
312+
nodeId,
313+
request.getRoutingState()
314+
)
315+
);
322316
return currentState;
323317
}
324318
if (existingAllocation.isRoutedToNode(nodeId) == false) {
325319
throw new ResourceNotFoundException("allocation for model with id [{}]] is not routed to node [{}]", modelId, nodeId);
326320
}
327-
builder.getAllocation(modelId)
328-
.updateExistingRoutingEntry(nodeId, request.getRoutingState())
329-
.calculateAndSetAllocationState();
321+
builder.getAllocation(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).calculateAndSetAllocationState();
330322

331323
return update(currentState, builder);
332324
}
@@ -342,7 +334,7 @@ static ClusterState removeAllocation(ClusterState currentState, String modelId)
342334
static ClusterState removeAllAllocations(ClusterState currentState) {
343335
if (TrainedModelAllocationMetadata.fromState(currentState).modelAllocations().isEmpty()) {
344336
return currentState;
345-
};
337+
}
346338
return ClusterState.builder(currentState)
347339
.metadata(
348340
Metadata.builder(currentState.metadata())
@@ -356,64 +348,62 @@ ClusterState addRemoveAllocationNodes(ClusterState currentState) {
356348
final TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
357349
final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
358350
Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
359-
Set<String> currentNotShuttingDownNodes = currentState.getNodes()
351+
Map<String, DiscoveryNode> currentEligibleNodes = currentState.getNodes()
360352
.getAllNodes()
361353
.stream()
362-
.map(DiscoveryNode::getId)
363-
.filter(id -> shuttingDownNodes.contains(id) == false)
364-
.collect(Collectors.toSet());
365-
// TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes)
366-
// It could probably be O(max(n, m))
367-
// Add nodes and keep track of currently routed nodes
368-
// Should we indicate a partial allocation somehow if some nodes don't have space?
369-
for (Map.Entry<String, TrainedModelAllocation> modelAllocationEntry : previousState.modelAllocations().entrySet()) {
370-
// Don't bother adding/removing nodes if this allocation is stopping
371-
if (modelAllocationEntry.getValue().getAllocationState().equals(AllocationState.STOPPING)) {
372-
continue;
373-
}
374-
final String modelId = modelAllocationEntry.getKey();
375-
Map<String, String> nodeToReason = new TreeMap<>();
376-
for (DiscoveryNode node : currentState.getNodes()) {
377-
// Only add the route if the node is NOT shutting down, this would be a weird case of the node
378-
// just being added to the cluster and immediately shutting down...
379-
if (shuttingDownNodes.contains(node.getId()) == false
380-
&& StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
381-
&& modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
382-
Optional<String> failure = nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
383-
if (failure.isPresent()) {
384-
nodeToReason.put(node.getName(), failure.get());
385-
} else {
386-
builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
354+
// TODO: Change when we update `mayAllocateToNode`
355+
.filter(node -> shuttingDownNodes.contains(node.getId()) == false
356+
&& StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node))
357+
.collect(Collectors.toMap(DiscoveryNode::getId, Function.identity()));
358+
// TODO: make more efficient, we iterate every entry, sorting by nodes routed (fewest to most)
359+
previousState.modelAllocations()
360+
.entrySet()
361+
.stream()
362+
.filter(entry -> entry.getValue().getAllocationState().equals(AllocationState.STOPPING) == false)
363+
.sorted(Comparator.comparing(e -> e.getValue().getNodeRoutingTable().size()))
364+
.forEach(modelAllocationEntry -> {
365+
final String modelId = modelAllocationEntry.getKey();
366+
Map<String, String> nodeToReason = new TreeMap<>();
367+
for (DiscoveryNode node : currentEligibleNodes.values()) {
368+
if (modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
369+
Optional<String> failure = builder.isChanged() ?
370+
// We use the builder only if we have changed, there is no point in creating a new object if we haven't changed
371+
nodeHasCapacity(currentState, builder, modelAllocationEntry.getValue().getTaskParams(), node) :
372+
nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
373+
if (failure.isPresent()) {
374+
nodeToReason.put(node.getName(), failure.get());
375+
} else {
376+
builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
377+
}
387378
}
388379
}
389-
}
390-
if (nodeToReason.isEmpty() == false) {
391-
builder.getAllocation(modelId)
392-
.setReason(
393-
nodeToReason.entrySet()
394-
.stream()
395-
.map(
396-
entry -> String.format(
397-
Locale.ROOT,
398-
"Not allocating on node [%s]. Reason: %s",
399-
entry.getKey(),
400-
entry.getValue()
380+
if (nodeToReason.isEmpty() == false) {
381+
builder.getAllocation(modelId)
382+
.setReason(
383+
nodeToReason.entrySet()
384+
.stream()
385+
.map(
386+
entry -> String.format(
387+
Locale.ROOT,
388+
"Not allocating on node [%s]. Reason: %s",
389+
entry.getKey(),
390+
entry.getValue()
391+
)
401392
)
402-
)
403-
.collect(Collectors.joining("|"))
404-
);
405-
} else {
406-
builder.getAllocation(modelId).clearReason();
407-
}
408-
for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
409-
if (currentNotShuttingDownNodes.contains(nodeId) == false) {
410-
builder.getAllocation(modelId).removeRoutingEntry(nodeId);
393+
.collect(Collectors.joining("|"))
394+
);
395+
} else {
396+
builder.getAllocation(modelId).clearReason();
411397
}
412-
}
413-
// It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
414-
// Or moved from PARTIALLY_STARTED to STARTED if a node was removed
415-
builder.getAllocation(modelId).calculateAndSetAllocationState();
416-
}
398+
for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
399+
if (currentEligibleNodes.containsKey(nodeId) == false) {
400+
builder.getAllocation(modelId).removeRoutingEntry(nodeId);
401+
}
402+
}
403+
// It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
404+
// Or moved from PARTIALLY_STARTED to STARTED if a node was removed
405+
builder.getAllocation(modelId).calculateAndSetAllocationState();
406+
});
417407
return update(currentState, builder);
418408
}
419409

@@ -448,8 +438,33 @@ static boolean shouldAllocateModels(final ClusterChangedEvent event) {
448438

449439
Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) {
450440
NodeLoad load = nodeLoadDetector.detectNodeLoad(state, true, node, Integer.MAX_VALUE, maxMemoryPercentage, useAuto);
441+
return handleNodeLoad(load, node.getId(), params);
442+
}
443+
444+
/**
445+
* Gather current node capacity taking the passed allocation metadata into account instead of the one stored in cluster state.
446+
*/
447+
Optional<String> nodeHasCapacity(
448+
ClusterState state,
449+
TrainedModelAllocationMetadata.Builder builder,
450+
StartTrainedModelDeploymentAction.TaskParams params,
451+
DiscoveryNode node
452+
) {
453+
NodeLoad load = nodeLoadDetector.detectNodeLoad(
454+
state,
455+
builder.build(),
456+
true,
457+
node,
458+
Integer.MAX_VALUE,
459+
maxMemoryPercentage,
460+
useAuto
461+
);
462+
return handleNodeLoad(load, node.getId(), params);
463+
}
464+
465+
Optional<String> handleNodeLoad(NodeLoad load, String nodeId, StartTrainedModelDeploymentAction.TaskParams params) {
451466
if (Strings.isNullOrEmpty(load.getError()) == false) {
452-
logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), node.getId());
467+
logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), nodeId);
453468
return Optional.of(load.getError());
454469
}
455470
if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) {
@@ -464,8 +479,7 @@ Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeployment
464479
load.getAssignedJobMemory(),
465480
ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(),
466481
params.estimateMemoryUsageBytes(),
467-
ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString()
468-
}
482+
ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString() }
469483
)
470484
);
471485
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ public NodeLoad detectNodeLoad(ClusterState clusterState,
4848
int dynamicMaxOpenJobs,
4949
int maxMachineMemoryPercent,
5050
boolean useAutoMachineMemoryCalculation) {
51+
return detectNodeLoad(
52+
clusterState,
53+
TrainedModelAllocationMetadata.fromState(clusterState),
54+
allNodesHaveDynamicMaxWorkers,
55+
node,
56+
dynamicMaxOpenJobs,
57+
maxMachineMemoryPercent,
58+
useAutoMachineMemoryCalculation
59+
);
60+
}
61+
62+
public NodeLoad detectNodeLoad(ClusterState clusterState,
63+
TrainedModelAllocationMetadata allocationMetadata,
64+
boolean allNodesHaveDynamicMaxWorkers,
65+
DiscoveryNode node,
66+
int dynamicMaxOpenJobs,
67+
int maxMachineMemoryPercent,
68+
boolean useAutoMachineMemoryCalculation) {
5169
PersistentTasksCustomMetadata persistentTasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
5270
Map<String, String> nodeAttributes = node.getAttributes();
5371
List<String> errors = new ArrayList<>();
@@ -80,7 +98,7 @@ public NodeLoad detectNodeLoad(ClusterState clusterState,
8098
return nodeLoad.setError(Strings.collectionToCommaDelimitedString(errors)).build();
8199
}
82100
updateLoadGivenTasks(nodeLoad, persistentTasks);
83-
updateLoadGivenModelAllocations(nodeLoad, TrainedModelAllocationMetadata.fromState(clusterState));
101+
updateLoadGivenModelAllocations(nodeLoad, allocationMetadata);
84102
return nodeLoad.build();
85103
}
86104

0 commit comments

Comments
 (0)