diff --git a/docs/changelog/83644.yaml b/docs/changelog/83644.yaml new file mode 100644 index 0000000000000..56c5c4b6bf5e7 --- /dev/null +++ b/docs/changelog/83644.yaml @@ -0,0 +1,5 @@ +pr: 83644 +summary: Wait for model process to be stop in stop deployment +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index cdc42a05c496d..a7ed543a835d0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; +import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState; import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata; @@ -88,6 +89,11 @@ protected void doExecute( }, listener::onFailure)); return; } + if (allocation.getAllocationState() == AllocationState.STOPPING) { + String message = "Trained model [" + deploymentId + "] is STOPPING"; + listener.onFailure(ExceptionsHelper.conflictStatusException(message)); + return; + } String[] randomRunningNode = allocation.getStartedNodes(); if (randomRunningNode.length == 0) { String message = "Trained model [" + deploymentId + "] is not allocated to any nodes"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 21a06afbf125f..f7481ccc59b3c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService; import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata; -import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import java.util.Collections; @@ -66,7 +65,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct private final Client client; private final IngestService ingestService; - private final TrainedModelAllocationService trainedModelAllocationService; private final TrainedModelAllocationClusterService trainedModelAllocationClusterService; @Inject @@ -76,7 +74,6 @@ public TransportStopTrainedModelDeploymentAction( ActionFilters actionFilters, Client client, IngestService ingestService, - TrainedModelAllocationService trainedModelAllocationService, TrainedModelAllocationClusterService trainedModelAllocationClusterService ) { super( @@ -91,7 +88,6 @@ public TransportStopTrainedModelDeploymentAction( ); this.client = new OriginSettingClient(client, ML_ORIGIN); this.ingestService = ingestService; - this.trainedModelAllocationService = trainedModelAllocationService; this.trainedModelAllocationClusterService = trainedModelAllocationClusterService; } @@ -150,6 +146,7 @@ protected void doExecute( } // NOTE, should only run on Master node + assert clusterService.localNode().isMasterNode(); trainedModelAllocationClusterService.setModelAllocationToStopping( modelId, ActionListener.wrap( @@ -196,30 +193,25 @@ private void normalUndeploy( ) { request.setNodes(modelAllocation.getNodeRoutingTable().keySet().toArray(String[]::new)); ActionListener finalListener = ActionListener.wrap(r -> { - waitForTaskRemoved(modelId, modelAllocation, request, r, ActionListener.wrap(waited -> { - trainedModelAllocationService.deleteModelAllocation( - modelId, - ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> { - logger.error( - () -> new ParameterizedMessage( - "[{}] failed to delete model allocation after nodes unallocated the deployment", - modelId - ), + assert clusterService.localNode().isMasterNode(); + trainedModelAllocationClusterService.removeModelAllocation( + modelId, + ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> { + logger.error( + () -> new ParameterizedMessage( + "[{}] failed to delete model allocation after nodes unallocated the deployment", + modelId + ), + deletionFailed + ); + listener.onFailure( + ExceptionsHelper.serverError( + "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again", deletionFailed - ); - listener.onFailure( - ExceptionsHelper.serverError( - "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again", - deletionFailed - ) - ); - }) - ); - }, - // TODO should we attempt to delete the deployment here? - listener::onFailure - )); - + ) + ); + }) + ); }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) { // A node has dropped out of the cluster since we started executing the requests. @@ -235,24 +227,6 @@ private void normalUndeploy( super.doExecute(task, request, finalListener); } - void waitForTaskRemoved( - String modelId, - TrainedModelAllocation trainedModelAllocation, - StopTrainedModelDeploymentAction.Request request, - StopTrainedModelDeploymentAction.Response response, - ActionListener listener - ) { - final Set nodesOfConcern = trainedModelAllocation.getNodeRoutingTable().keySet(); - client.admin() - .cluster() - .prepareListTasks(nodesOfConcern.toArray(String[]::new)) - .setDetailed(true) - .setWaitForCompletion(true) - .setActions(modelId) - .setTimeout(request.getTimeout()) - .execute(ActionListener.wrap(complete -> listener.onResponse(response), listener::onFailure)); - } - @Override protected StopTrainedModelDeploymentAction.Response newResponse( StopTrainedModelDeploymentAction.Request request, @@ -275,7 +249,9 @@ protected void taskOperation( TrainedModelDeploymentTask task, ActionListener listener ) { - task.stop("undeploy_trained_model (api)"); - listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); + task.stop( + "undeploy_trained_model (api)", + ActionListener.wrap(r -> listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)), listener::onFailure) + ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java index dd80adc9cecd7..5815a843a076b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java @@ -135,7 +135,8 @@ void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionL if (stopped) { return; } - task.stopWithoutNotification(reason); + task.markAsStopped(reason); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { try { deploymentManager.stopDeployment(task); @@ -204,20 +205,12 @@ void loadQueuedModels() { loadingModels.addAll(loadingToRetry); } - public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason) { + public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener listener) { ActionListener notifyDeploymentOfStopped = ActionListener.wrap( - _void -> updateStoredState( - task.getModelId(), - new RoutingStateAndReason(RoutingState.STOPPED, reason), - ActionListener.wrap(s -> {}, failure -> {}) - ), + _void -> updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener), failed -> { // if we failed to stop the process, something strange is going on, but we should still notify of stop logger.warn(() -> new ParameterizedMessage("[{}] failed to stop due to error", task.getModelId()), failed); - updateStoredState( - task.getModelId(), - new RoutingStateAndReason(RoutingState.STOPPED, reason), - ActionListener.wrap(s -> {}, failure -> {}) - ); + updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener); } ); updateStoredState( @@ -309,7 +302,7 @@ public void clusterChanged(ClusterChangedEvent event) { && isResetMode == false) { prepareModelToLoad(trainedModelAllocation.getTaskParams()); } - // This mode is not routed to the current node at all + // This model is not routed to the current node at all if (routingStateAndReason == null) { TrainedModelDeploymentTask task = modelIdToTask.remove(trainedModelAllocation.getTaskParams().getModelId()); if (task != null) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 64f9bb58fb664..e2e0162c12cce 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -9,9 +9,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.core.TimeValue; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; @@ -80,15 +82,11 @@ public TaskParams getParams() { return params; } - public void stop(String reason) { - logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); - licensedFeature.stopTracking(licenseState, "model-" + params.getModelId()); - stopped = true; - stoppedReasonHolder.trySet(reason); - trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason); + public void stop(String reason, ActionListener listener) { + trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason, listener); } - public void stopWithoutNotification(String reason) { + public void markAsStopped(String reason) { licensedFeature.stopTracking(licenseState, "model-" + params.getModelId()); logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason); stoppedReasonHolder.trySet(reason); @@ -106,7 +104,14 @@ public Optional stoppedReason() { @Override protected void onCancelled() { String reason = getReasonCancelled(); - stop(reason); + logger.info("[{}] task cancelled due to reason [{}]", getModelId(), reason); + stop( + reason, + ActionListener.wrap( + acknowledgedResponse -> {}, + e -> logger.error(new ParameterizedMessage("[{}] error stopping the model after task cancellation", getModelId()), e) + ) + ); } public void infer(Map doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener listener) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java index 2852a452a627c..7f19e41135858 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java @@ -196,7 +196,7 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { // Only one model should be loaded, the other should be stopped trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad)); trainedModelAllocationNodeService.prepareModelToLoad(newParams(stoppedModelToLoad)); - trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing"); + trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing", ActionListener.wrap(r -> {}, e -> {})); trainedModelAllocationNodeService.loadQueuedModels(); assertBusy(() -> { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java index 330559b495abb..259a5e4ca9827 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.deployment; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.TaskId; @@ -14,12 +15,15 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService; +import org.mockito.ArgumentCaptor; import java.util.Map; import java.util.function.Consumer; import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_ACTION; import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -29,6 +33,15 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase { void assertTrackingComplete(Consumer method, String modelId) { XPackLicenseState licenseState = mock(XPackLicenseState.class); LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class); + TrainedModelAllocationNodeService nodeService = mock(TrainedModelAllocationNodeService.class); + + ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); + ArgumentCaptor reasonCaptur = ArgumentCaptor.forClass(String.class); + doAnswer(invocation -> { + taskCaptor.getValue().markAsStopped(reasonCaptur.getValue()); + return null; + }).when(nodeService).stopDeploymentAndNotify(taskCaptor.capture(), reasonCaptur.capture(), any()); + TrainedModelDeploymentTask task = new TrainedModelDeploymentTask( 0, TRAINED_MODEL_ALLOCATION_TASK_TYPE, @@ -42,7 +55,7 @@ void assertTrackingComplete(Consumer method, String randomInt(5), randomInt(5) ), - mock(TrainedModelAllocationNodeService.class), + nodeService, licenseState, feature ); @@ -53,12 +66,12 @@ void assertTrackingComplete(Consumer method, String verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId); } - public void testOnStopWithoutNotification() { - assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10)); + public void testMarkAsStopped() { + assertTrackingComplete(t -> t.markAsStopped("foo"), randomAlphaOfLength(10)); } public void testOnStop() { - assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10)); + assertTrackingComplete(t -> t.stop("foo", ActionListener.wrap(r -> {}, e -> {})), randomAlphaOfLength(10)); } public void testCancelled() {