4040import org .elasticsearch .xpack .ml .job .NodeLoadDetector ;
4141
4242import java .util .Collections ;
43- import java .util .List ;
43+ import java .util .Comparator ;
4444import java .util .Locale ;
4545import java .util .Map ;
4646import java .util .Optional ;
4747import java .util .Set ;
4848import java .util .TreeMap ;
49+ import java .util .function .Function ;
4950import java .util .stream .Collectors ;
5051
5152public class TrainedModelAllocationClusterService implements ClusterStateListener {
@@ -245,8 +246,7 @@ ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelD
245246 Set <String > shuttingDownNodes = nodesShuttingDown (currentState );
246247 Map <String , String > nodeToReason = new TreeMap <>();
247248 for (DiscoveryNode node : currentState .getNodes ().getAllNodes ()) {
248- if (StartTrainedModelDeploymentAction .TaskParams .mayAllocateToNode (node )
249- && shuttingDownNodes .contains (node .getId ()) == false ) {
249+ if (StartTrainedModelDeploymentAction .TaskParams .mayAllocateToNode (node ) && shuttingDownNodes .contains (node .getId ()) == false ) {
250250 Optional <String > maybeError = nodeHasCapacity (currentState , params , node );
251251 if (maybeError .isPresent ()) {
252252 nodeToReason .put (node .getName (), maybeError .get ());
@@ -289,16 +289,8 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
289289 logger .trace (
290290 () -> new ParameterizedMessage ("[{}] [{}] current metadata before update {}" , modelId , nodeId , Strings .toString (metadata ))
291291 );
292- Set <String > shuttingDownNodes = nodesShuttingDown (currentState );
293- List <DiscoveryNode > allocatableNodes = currentState .nodes ()
294- .getAllNodes ()
295- .stream ()
296- .filter (
297- d -> StartTrainedModelDeploymentAction .TaskParams .mayAllocateToNode (d ) && shuttingDownNodes .contains (d .getId ()) == false
298- )
299- .collect (Collectors .toList ());
300292 final TrainedModelAllocation existingAllocation = metadata .getModelAllocation (modelId );
301- final TrainedModelAllocationMetadata .Builder builder = TrainedModelAllocationMetadata .builder (currentState );
293+ final TrainedModelAllocationMetadata .Builder builder = TrainedModelAllocationMetadata .builder (currentState );
302294 // If state is stopped, this indicates the node process is closed, remove the node from the allocation
303295 if (request .getRoutingState ().getState ().equals (RoutingState .STOPPED )) {
304296 if (existingAllocation == null || existingAllocation .isRoutedToNode (nodeId ) == false ) {
@@ -313,20 +305,20 @@ static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTra
313305 }
314306 // If we are stopping, don't update anything
315307 if (existingAllocation .getAllocationState ().equals (AllocationState .STOPPING )) {
316- logger .debug (() -> new ParameterizedMessage (
317- "[{}] requested update from node [{}] to update route state to [{}]" ,
318- modelId ,
319- nodeId ,
320- request .getRoutingState ()
321- ));
308+ logger .debug (
309+ () -> new ParameterizedMessage (
310+ "[{}] requested update from node [{}] to update route state to [{}]" ,
311+ modelId ,
312+ nodeId ,
313+ request .getRoutingState ()
314+ )
315+ );
322316 return currentState ;
323317 }
324318 if (existingAllocation .isRoutedToNode (nodeId ) == false ) {
325319 throw new ResourceNotFoundException ("allocation for model with id [{}]] is not routed to node [{}]" , modelId , nodeId );
326320 }
327- builder .getAllocation (modelId )
328- .updateExistingRoutingEntry (nodeId , request .getRoutingState ())
329- .calculateAndSetAllocationState ();
321+ builder .getAllocation (modelId ).updateExistingRoutingEntry (nodeId , request .getRoutingState ()).calculateAndSetAllocationState ();
330322
331323 return update (currentState , builder );
332324 }
@@ -342,7 +334,7 @@ static ClusterState removeAllocation(ClusterState currentState, String modelId)
342334 static ClusterState removeAllAllocations (ClusterState currentState ) {
343335 if (TrainedModelAllocationMetadata .fromState (currentState ).modelAllocations ().isEmpty ()) {
344336 return currentState ;
345- };
337+ }
346338 return ClusterState .builder (currentState )
347339 .metadata (
348340 Metadata .builder (currentState .metadata ())
@@ -356,64 +348,62 @@ ClusterState addRemoveAllocationNodes(ClusterState currentState) {
356348 final TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata .fromState (currentState );
357349 final TrainedModelAllocationMetadata .Builder builder = TrainedModelAllocationMetadata .builder (currentState );
358350 Set <String > shuttingDownNodes = nodesShuttingDown (currentState );
359- Set <String > currentNotShuttingDownNodes = currentState .getNodes ()
351+ Map <String , DiscoveryNode > currentEligibleNodes = currentState .getNodes ()
360352 .getAllNodes ()
361353 .stream ()
362- .map (DiscoveryNode ::getId )
363- .filter (id -> shuttingDownNodes .contains (id ) == false )
364- .collect (Collectors .toSet ());
365- // TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes)
366- // It could probably be O(max(n, m))
367- // Add nodes and keep track of currently routed nodes
368- // Should we indicate a partial allocation somehow if some nodes don't have space?
369- for (Map .Entry <String , TrainedModelAllocation > modelAllocationEntry : previousState .modelAllocations ().entrySet ()) {
370- // Don't bother adding/removing nodes if this allocation is stopping
371- if (modelAllocationEntry .getValue ().getAllocationState ().equals (AllocationState .STOPPING )) {
372- continue ;
373- }
374- final String modelId = modelAllocationEntry .getKey ();
375- Map <String , String > nodeToReason = new TreeMap <>();
376- for (DiscoveryNode node : currentState .getNodes ()) {
377- // Only add the route if the node is NOT shutting down, this would be a weird case of the node
378- // just being added to the cluster and immediately shutting down...
379- if (shuttingDownNodes .contains (node .getId ()) == false
380- && StartTrainedModelDeploymentAction .TaskParams .mayAllocateToNode (node )
381- && modelAllocationEntry .getValue ().isRoutedToNode (node .getId ()) == false ) {
382- Optional <String > failure = nodeHasCapacity (currentState , modelAllocationEntry .getValue ().getTaskParams (), node );
383- if (failure .isPresent ()) {
384- nodeToReason .put (node .getName (), failure .get ());
385- } else {
386- builder .getAllocation (modelId ).addNewRoutingEntry (node .getId ());
354+ // TODO: Change when we update `mayAllocateToNode`
355+ .filter (node -> shuttingDownNodes .contains (node .getId ()) == false
356+ && StartTrainedModelDeploymentAction .TaskParams .mayAllocateToNode (node ))
357+ .collect (Collectors .toMap (DiscoveryNode ::getId , Function .identity ()));
358+ // TODO: make more efficient, we iterate every entry, sorting by nodes routed (fewest to most)
359+ previousState .modelAllocations ()
360+ .entrySet ()
361+ .stream ()
362+ .filter (entry -> entry .getValue ().getAllocationState ().equals (AllocationState .STOPPING ) == false )
363+ .sorted (Comparator .comparing (e -> e .getValue ().getNodeRoutingTable ().size ()))
364+ .forEach (modelAllocationEntry -> {
365+ final String modelId = modelAllocationEntry .getKey ();
366+ Map <String , String > nodeToReason = new TreeMap <>();
367+ for (DiscoveryNode node : currentEligibleNodes .values ()) {
368+ if (modelAllocationEntry .getValue ().isRoutedToNode (node .getId ()) == false ) {
369+ Optional <String > failure = builder .isChanged () ?
370+ // We use the builder only if we have changed, there is no point in creating a new object if we haven't changed
371+ nodeHasCapacity (currentState , builder , modelAllocationEntry .getValue ().getTaskParams (), node ) :
372+ nodeHasCapacity (currentState , modelAllocationEntry .getValue ().getTaskParams (), node );
373+ if (failure .isPresent ()) {
374+ nodeToReason .put (node .getName (), failure .get ());
375+ } else {
376+ builder .getAllocation (modelId ).addNewRoutingEntry (node .getId ());
377+ }
387378 }
388379 }
389- }
390- if ( nodeToReason . isEmpty () == false ) {
391- builder . getAllocation ( modelId )
392- . setReason (
393- nodeToReason . entrySet ()
394- . stream ()
395- . map (
396- entry -> String . format (
397- Locale . ROOT ,
398- "Not allocating on node [%s]. Reason: %s" ,
399- entry .getKey (),
400- entry . getValue ( )
380+ if ( nodeToReason . isEmpty () == false ) {
381+ builder . getAllocation ( modelId )
382+ . setReason (
383+ nodeToReason . entrySet ()
384+ . stream ()
385+ . map (
386+ entry -> String . format (
387+ Locale . ROOT ,
388+ "Not allocating on node [%s]. Reason: %s" ,
389+ entry . getKey () ,
390+ entry .getValue ()
391+ )
401392 )
402- )
403- .collect (Collectors .joining ("|" ))
404- );
405- } else {
406- builder .getAllocation (modelId ).clearReason ();
407- }
408- for (String nodeId : modelAllocationEntry .getValue ().getNodeRoutingTable ().keySet ()) {
409- if (currentNotShuttingDownNodes .contains (nodeId ) == false ) {
410- builder .getAllocation (modelId ).removeRoutingEntry (nodeId );
393+ .collect (Collectors .joining ("|" ))
394+ );
395+ } else {
396+ builder .getAllocation (modelId ).clearReason ();
411397 }
412- }
413- // It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
414- // Or moved from PARTIALLY_STARTED to STARTED if a node was removed
415- builder .getAllocation (modelId ).calculateAndSetAllocationState ();
416- }
398+ for (String nodeId : modelAllocationEntry .getValue ().getNodeRoutingTable ().keySet ()) {
399+ if (currentEligibleNodes .containsKey (nodeId ) == false ) {
400+ builder .getAllocation (modelId ).removeRoutingEntry (nodeId );
401+ }
402+ }
403+ // It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
404+ // Or moved from PARTIALLY_STARTED to STARTED if a node was removed
405+ builder .getAllocation (modelId ).calculateAndSetAllocationState ();
406+ });
417407 return update (currentState , builder );
418408 }
419409
@@ -448,8 +438,33 @@ static boolean shouldAllocateModels(final ClusterChangedEvent event) {
448438
449439 Optional <String > nodeHasCapacity (ClusterState state , StartTrainedModelDeploymentAction .TaskParams params , DiscoveryNode node ) {
450440 NodeLoad load = nodeLoadDetector .detectNodeLoad (state , true , node , Integer .MAX_VALUE , maxMemoryPercentage , useAuto );
441+ return handleNodeLoad (load , node .getId (), params );
442+ }
443+
444+ /**
445+ * Gather current node capacity taking the passed allocation metadata into account instead of the one stored in cluster state.
446+ */
447+ Optional <String > nodeHasCapacity (
448+ ClusterState state ,
449+ TrainedModelAllocationMetadata .Builder builder ,
450+ StartTrainedModelDeploymentAction .TaskParams params ,
451+ DiscoveryNode node
452+ ) {
453+ NodeLoad load = nodeLoadDetector .detectNodeLoad (
454+ state ,
455+ builder .build (),
456+ true ,
457+ node ,
458+ Integer .MAX_VALUE ,
459+ maxMemoryPercentage ,
460+ useAuto
461+ );
462+ return handleNodeLoad (load , node .getId (), params );
463+ }
464+
465+ Optional <String > handleNodeLoad (NodeLoad load , String nodeId , StartTrainedModelDeploymentAction .TaskParams params ) {
451466 if (Strings .isNullOrEmpty (load .getError ()) == false ) {
452- logger .warn ("[{}] failed to calculate current node load with error [{}]" , params .getModelId (), node . getId () );
467+ logger .warn ("[{}] failed to calculate current node load with error [{}]" , params .getModelId (), nodeId );
453468 return Optional .of (load .getError ());
454469 }
455470 if (load .getFreeMemory () < params .estimateMemoryUsageBytes ()) {
@@ -464,8 +479,7 @@ Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeployment
464479 load .getAssignedJobMemory (),
465480 ByteSizeValue .ofBytes (load .getAssignedJobMemory ()).toString (),
466481 params .estimateMemoryUsageBytes (),
467- ByteSizeValue .ofBytes (params .estimateMemoryUsageBytes ()).toString ()
468- }
482+ ByteSizeValue .ofBytes (params .estimateMemoryUsageBytes ()).toString () }
469483 )
470484 );
471485 }
0 commit comments