diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index c540668d8fdd6..8654a0807c418 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; @@ -143,11 +144,13 @@ protected void taskOperation( TrainedModelDeploymentTask task, ActionListener listener ) { + assert actionTask instanceof CancellableTask : "task [" + actionTask + "] not cancellable"; task.infer( request.getDocs().get(0), request.getUpdate(), request.isSkipQueue(), request.getInferenceTimeout(), + actionTask, ActionListener.wrap( pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)), listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index aa8445647745e..1d48f1d1f2297 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -277,9 +277,10 @@ public void infer( Map doc, boolean skipQueue, TimeValue timeout, + Task parentActionTask, ActionListener listener ) { - deploymentManager.infer(task, config, doc, skipQueue, timeout, listener); + deploymentManager.infer(task, config, doc, skipQueue, timeout, parentActionTask, listener); } public Optional modelStats(TrainedModelDeploymentTask task) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 35e7f619a8e83..90dccf138fe1a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.IdsQueryBuilder; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentFactory; @@ -237,6 +238,7 @@ public void infer( Map doc, boolean skipQueue, TimeValue timeout, + Task parentActionTask, ActionListener listener ) { var processContext = getProcessContext(task, listener::onFailure); @@ -254,6 +256,7 @@ public void infer( config, doc, threadPool, + parentActionTask, listener ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java index 71220194ba58a..720751dd617c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java @@ -10,7 +10,11 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; @@ -33,6 +37,7 @@ class InferencePyTorchAction extends AbstractPyTorchAction { private final InferenceConfig config; private final Map doc; + private final Task parentActionTask; InferencePyTorchAction( String modelId, @@ -42,11 +47,25 @@ class InferencePyTorchAction extends AbstractPyTorchAction { InferenceConfig config, Map doc, ThreadPool threadPool, + @Nullable Task parentActionTask, ActionListener listener ) { super(modelId, requestId, timeout, processContext, threadPool, listener); this.config = config; this.doc = doc; + this.parentActionTask = parentActionTask; + } + + private boolean isCancelled() { + if (parentActionTask instanceof CancellableTask cancellableTask) { + try { + cancellableTask.ensureNotCancelled(); + } catch (TaskCancelledException ex) { + logger.debug(() -> format("[%s] %s", getModelId(), ex.getMessage())); + return true; + } + } + return false; } @Override @@ -56,12 +75,15 @@ protected void doRun() throws Exception { logger.debug(() -> format("[%s] skipping inference on request [%s] as it has timed out", getModelId(), getRequestId())); return; } + if (isCancelled()) { + onFailure("inference task cancelled"); + return; + } final String requestIdStr = String.valueOf(getRequestId()); try { // The request builder expect a list of inputs which are then batched. - // TODO batching was implemented for expected use-cases such as zero-shot - // classification but is not used here. + // TODO batching was implemented for expected use-cases such as zero-shot classification but is not used here. List text = Collections.singletonList(NlpTask.extractInput(getProcessContext().getModelInput().get(), doc)); NlpTask.Processor processor = getProcessContext().getNlpTaskProcessor().get(); processor.validateInputs(text); @@ -74,6 +96,11 @@ protected void doRun() throws Exception { logger.debug("[{}] [{}] input truncated", getModelId(), getRequestId()); } + // Tokenization is non-trivial, so check for cancellation one last time before sending request to the native process + if (isCancelled()) { + onFailure("inference task cancelled"); + return; + } getProcessContext().getResultProcessor() .registerRequest( requestIdStr, @@ -109,6 +136,10 @@ private void processResult( ); return; } + if (isCancelled()) { + onFailure("inference task cancelled"); + return; + } InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult.inferenceResult()); logger.debug(() -> format("[%s] processed result for request [%s]", getModelId(), getRequestId())); onSuccess(results); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 72e706ca595c6..caef67ddab889 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -18,6 +18,7 @@ import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; @@ -132,6 +133,7 @@ public void infer( InferenceConfigUpdate update, boolean skipQueue, TimeValue timeout, + Task parentActionTask, ActionListener listener ) { if (inferenceConfigHolder.get() == null) { @@ -150,7 +152,15 @@ public void infer( ); return; } - trainedModelAssignmentNodeService.infer(this, update.apply(inferenceConfigHolder.get()), doc, skipQueue, timeout, listener); + trainedModelAssignmentNodeService.infer( + this, + update.apply(inferenceConfigHolder.get()), + doc, + skipQueue, + timeout, + parentActionTask, + listener + ); } public Optional modelStats() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java index 10b2813603d59..4350428b221a2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java @@ -101,6 +101,7 @@ public void testRejectedExecution() { Map.of(), false, TimeValue.timeValueMinutes(1), + null, ActionListener.wrap(result -> fail("unexpected success"), e -> assertThat(e, instanceOf(EsRejectedExecutionException.class))) ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java index 4a3a23a6622a2..4590aeb2a8888 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java @@ -8,7 +8,13 @@ package org.elasticsearch.xpack.ml.inference.deployment; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskAwareRequest; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; @@ -21,6 +27,7 @@ import org.junit.Before; import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; @@ -64,7 +71,7 @@ public void testInferListenerOnlyCalledOnce() { AtomicInteger timeoutCount = new AtomicInteger(); when(processContext.getTimeoutCount()).thenReturn(timeoutCount); - ListenerCounter listener = new ListenerCounter(); + TestListenerCounter listener = new TestListenerCounter(); InferencePyTorchAction action = new InferencePyTorchAction( "test-model", 1, @@ -73,6 +80,7 @@ public void testInferListenerOnlyCalledOnce() { new PassThroughConfig(null, null, null), Map.of(), tp, + null, listener ); action.init(); @@ -93,6 +101,7 @@ public void testInferListenerOnlyCalledOnce() { new PassThroughConfig(null, null, null), Map.of(), tp, + null, listener ); action.init(); @@ -114,6 +123,7 @@ public void testInferListenerOnlyCalledOnce() { new PassThroughConfig(null, null, null), Map.of(), tp, + null, listener ); action.init(); @@ -134,7 +144,7 @@ public void testRunNotCalledAfterNotified() { AtomicInteger timeoutCount = new AtomicInteger(); when(processContext.getTimeoutCount()).thenReturn(timeoutCount); - ListenerCounter listener = new ListenerCounter(); + TestListenerCounter listener = new TestListenerCounter(); { InferencePyTorchAction action = new InferencePyTorchAction( "test-model", @@ -144,6 +154,7 @@ public void testRunNotCalledAfterNotified() { new PassThroughConfig(null, null, null), Map.of(), tp, + null, listener ); action.init(); @@ -161,6 +172,7 @@ public void testRunNotCalledAfterNotified() { new PassThroughConfig(null, null, null), Map.of(), tp, + null, listener ); action.init(); @@ -170,7 +182,49 @@ public void testRunNotCalledAfterNotified() { } } - static class ListenerCounter implements ActionListener { + public void testCallingRunAfterParentTaskCancellation() throws Exception { + DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class); + PyTorchResultProcessor resultProcessor = mock(PyTorchResultProcessor.class); + when(processContext.getResultProcessor()).thenReturn(resultProcessor); + AtomicInteger timeoutCount = new AtomicInteger(); + when(processContext.getTimeoutCount()).thenReturn(timeoutCount); + TaskManager taskManager = new TaskManager(Settings.EMPTY, tp, Set.of()); + TestListenerCounter listener = new TestListenerCounter(); + CancellableTask cancellableTask = (CancellableTask) taskManager.register("test_task", "testAction", new TaskAwareRequest() { + @Override + public void setParentTask(TaskId taskId) {} + + @Override + public TaskId getParentTask() { + return TaskId.EMPTY_TASK_ID; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers); + } + }); + InferencePyTorchAction action = new InferencePyTorchAction( + "test-model", + 1, + TimeValue.MAX_VALUE, + processContext, + new PassThroughConfig(null, null, null), + Map.of(), + tp, + cancellableTask, + listener + ); + action.init(); + taskManager.cancel(cancellableTask, "test", () -> {}); + + action.doRun(); + assertThat(listener.failureCounts, equalTo(1)); + assertThat(listener.responseCounts, equalTo(0)); + verify(resultProcessor, never()).registerRequest(anyString(), any()); + } + + static class TestListenerCounter implements ActionListener { private int responseCounts; private int failureCounts;