
Commit feb0024

[ML] make deployment infer requests fully cancellable (#88649)
When an infer request is made, it may or may not be queued for later execution. If the caller making the inference request stops listening for the result, we should not execute the action. This commit allows infer requests made to deployed models to be cancelled even after they have been queued for inference. Related to: #88009
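
For context, here is a minimal sketch of the mechanism (illustrative class and names, not code from this commit): the queued action keeps a reference to the parent task and, before doing any work, asks whether the caller has already gone away. Only a CancellableTask carries a cancellation flag, so any other task is treated as never cancelled.

```java
import java.util.Map;

import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;

// Hypothetical stand-in for a queued inference action; mirrors the
// isCancelled() helper this commit adds to InferencePyTorchAction.
public class QueuedCancellationSketch {

    private final Task parentActionTask; // @Nullable in the real action: null means "no calling task"

    QueuedCancellationSketch(Task parentActionTask) {
        this.parentActionTask = parentActionTask;
    }

    boolean isCancelled() {
        if (parentActionTask instanceof CancellableTask cancellableTask) {
            try {
                cancellableTask.ensureNotCancelled(); // throws once the task has been cancelled
            } catch (TaskCancelledException ex) {
                return true;
            }
        }
        return false; // not a cancellable task, or not cancelled
    }

    public static void main(String[] args) {
        CancellableTask task = new CancellableTask(1, "transport", "test/action", "demo", TaskId.EMPTY_TASK_ID, Map.of());
        System.out.println(new QueuedCancellationSketch(task).isCancelled()); // false: nothing cancelled it yet
    }
}
```

Accepting the parent as a plain, nullable Task keeps the new parameter optional, which is why the updated tests can simply pass null where no parent task exists.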
1 parent: 6bbe32f

7 files changed: 110 additions, 7 deletions

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

Lines changed: 3 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -143,11 +144,13 @@ protected void taskOperation(
         TrainedModelDeploymentTask task,
         ActionListener<InferTrainedModelDeploymentAction.Response> listener
     ) {
+        assert actionTask instanceof CancellableTask : "task [" + actionTask + "] not cancellable";
         task.infer(
             request.getDocs().get(0),
             request.getUpdate(),
             request.isSkipQueue(),
             request.getInferenceTimeout(),
+            actionTask,
             ActionListener.wrap(
                 pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)),
                 listener::onFailure
```

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 2 additions & 1 deletion
```diff
@@ -277,9 +277,10 @@ public void infer(
         Map<String, Object> doc,
         boolean skipQueue,
         TimeValue timeout,
+        Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
-        deploymentManager.infer(task, config, doc, skipQueue, timeout, listener);
+        deploymentManager.infer(task, config, doc, skipQueue, timeout, parentActionTask, listener);
     }

     public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {
```

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 3 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentFactory;
@@ -237,6 +238,7 @@ public void infer(
         Map<String, Object> doc,
         boolean skipQueue,
         TimeValue timeout,
+        Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
         var processContext = getProcessContext(task, listener::onFailure);
@@ -254,6 +256,7 @@ public void infer(
             config,
             doc,
             threadPool,
+            parentActionTask,
             listener
         );
```

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java

Lines changed: 33 additions & 2 deletions
```diff
@@ -10,7 +10,11 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -33,6 +37,7 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {

     private final InferenceConfig config;
     private final Map<String, Object> doc;
+    private final Task parentActionTask;

     InferencePyTorchAction(
         String modelId,
@@ -42,11 +47,25 @@ class InferencePyTorchAction extends AbstractPyTorchAction<InferenceResults> {
         InferenceConfig config,
         Map<String, Object> doc,
         ThreadPool threadPool,
+        @Nullable Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
         super(modelId, requestId, timeout, processContext, threadPool, listener);
         this.config = config;
         this.doc = doc;
+        this.parentActionTask = parentActionTask;
+    }
+
+    private boolean isCancelled() {
+        if (parentActionTask instanceof CancellableTask cancellableTask) {
+            try {
+                cancellableTask.ensureNotCancelled();
+            } catch (TaskCancelledException ex) {
+                logger.debug(() -> format("[%s] %s", getModelId(), ex.getMessage()));
+                return true;
+            }
+        }
+        return false;
     }

     @Override
@@ -56,12 +75,15 @@ protected void doRun() throws Exception {
             logger.debug(() -> format("[%s] skipping inference on request [%s] as it has timed out", getModelId(), getRequestId()));
             return;
         }
+        if (isCancelled()) {
+            onFailure("inference task cancelled");
+            return;
+        }

         final String requestIdStr = String.valueOf(getRequestId());
         try {
             // The request builder expect a list of inputs which are then batched.
-            // TODO batching was implemented for expected use-cases such as zero-shot
-            // classification but is not used here.
+            // TODO batching was implemented for expected use-cases such as zero-shot classification but is not used here.
             List<String> text = Collections.singletonList(NlpTask.extractInput(getProcessContext().getModelInput().get(), doc));
             NlpTask.Processor processor = getProcessContext().getNlpTaskProcessor().get();
             processor.validateInputs(text);
@@ -74,6 +96,11 @@ protected void doRun() throws Exception {
                 logger.debug("[{}] [{}] input truncated", getModelId(), getRequestId());
             }

+            // Tokenization is non-trivial, so check for cancellation one last time before sending request to the native process
+            if (isCancelled()) {
+                onFailure("inference task cancelled");
+                return;
+            }
             getProcessContext().getResultProcessor()
                 .registerRequest(
                     requestIdStr,
@@ -109,6 +136,10 @@ private void processResult(
             );
             return;
         }
+        if (isCancelled()) {
+            onFailure("inference task cancelled");
+            return;
+        }
         InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult.inferenceResult());
         logger.debug(() -> format("[%s] processed result for request [%s]", getModelId(), getRequestId()));
         onSuccess(results);
```
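
With these edits the action re-checks cancellation at three points: when the queued request is finally run, after tokenization but before the request reaches the native process, and once more when the native result comes back. A condensed, hypothetical sketch of that control flow (the fields and methods below are stand-ins, not the real InferencePyTorchAction API):

```java
// Condensed flow sketch; `cancelled`, `tokenize`, and `sendToNativeProcess` are
// illustrative stand-ins for the parent task's flag and the real pipeline steps.
class CancellationFlowSketch {

    private volatile boolean cancelled;

    boolean isCancelled() {
        return cancelled;
    }

    void onFailure(String reason) {
        System.out.println("failed: " + reason);
    }

    void doRun() {
        if (isCancelled()) { onFailure("inference task cancelled"); return; } // cancelled while queued
        tokenize();                                                           // non-trivial CPU work
        if (isCancelled()) { onFailure("inference task cancelled"); return; } // cancelled during tokenization
        sendToNativeProcess();                                                // async; processResult fires later
    }

    void processResult() {
        if (isCancelled()) { onFailure("inference task cancelled"); return; } // caller left while inferring
        System.out.println("success");
    }

    void tokenize() {}

    void sendToNativeProcess() {}
}
```

Cancellation observed at any checkpoint is surfaced through onFailure rather than dropped silently, and the existing testInferListenerOnlyCalledOnce coverage already checks that the listener fires exactly once however the action terminates.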

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

Lines changed: 11 additions & 1 deletion
```diff
@@ -18,6 +18,7 @@
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -132,6 +133,7 @@ public void infer(
         InferenceConfigUpdate update,
         boolean skipQueue,
         TimeValue timeout,
+        Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
         if (inferenceConfigHolder.get() == null) {
@@ -150,7 +152,15 @@ public void infer(
             );
             return;
         }
-        trainedModelAssignmentNodeService.infer(this, update.apply(inferenceConfigHolder.get()), doc, skipQueue, timeout, listener);
+        trainedModelAssignmentNodeService.infer(
+            this,
+            update.apply(inferenceConfigHolder.get()),
+            doc,
+            skipQueue,
+            timeout,
+            parentActionTask,
+            listener
+        );
     }

     public Optional<ModelStats> modelStats() {
```

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

Lines changed: 1 addition & 0 deletions
```diff
@@ -101,6 +101,7 @@ public void testRejectedExecution() {
             Map.of(),
             false,
             TimeValue.timeValueMinutes(1),
+            null,
             ActionListener.wrap(result -> fail("unexpected success"), e -> assertThat(e, instanceOf(EsRejectedExecutionException.class)))
         );
```

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java

Lines changed: 57 additions & 3 deletions
```diff
@@ -8,7 +8,13 @@
 package org.elasticsearch.xpack.ml.inference.deployment;

 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskAwareRequest;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -21,6 +27,7 @@
 import org.junit.Before;

 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;

 import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
@@ -64,7 +71,7 @@ public void testInferListenerOnlyCalledOnce() {
         AtomicInteger timeoutCount = new AtomicInteger();
         when(processContext.getTimeoutCount()).thenReturn(timeoutCount);

-        ListenerCounter listener = new ListenerCounter();
+        TestListenerCounter listener = new TestListenerCounter();
         InferencePyTorchAction action = new InferencePyTorchAction(
             "test-model",
             1,
@@ -73,6 +80,7 @@ public void testInferListenerOnlyCalledOnce() {
             new PassThroughConfig(null, null, null),
             Map.of(),
             tp,
+            null,
             listener
         );
         action.init();
@@ -93,6 +101,7 @@ public void testInferListenerOnlyCalledOnce() {
             new PassThroughConfig(null, null, null),
             Map.of(),
             tp,
+            null,
             listener
         );
         action.init();
@@ -114,6 +123,7 @@ public void testInferListenerOnlyCalledOnce() {
             new PassThroughConfig(null, null, null),
             Map.of(),
             tp,
+            null,
             listener
         );
         action.init();
@@ -134,7 +144,7 @@ public void testRunNotCalledAfterNotified() {
         AtomicInteger timeoutCount = new AtomicInteger();
         when(processContext.getTimeoutCount()).thenReturn(timeoutCount);

-        ListenerCounter listener = new ListenerCounter();
+        TestListenerCounter listener = new TestListenerCounter();
         {
             InferencePyTorchAction action = new InferencePyTorchAction(
                 "test-model",
@@ -144,6 +154,7 @@ public void testRunNotCalledAfterNotified() {
                 new PassThroughConfig(null, null, null),
                 Map.of(),
                 tp,
+                null,
                 listener
             );
             action.init();
@@ -161,6 +172,7 @@ public void testRunNotCalledAfterNotified() {
                 new PassThroughConfig(null, null, null),
                 Map.of(),
                 tp,
+                null,
                 listener
             );
             action.init();
@@ -170,7 +182,49 @@ public void testRunNotCalledAfterNotified() {
         }
     }

-    static class ListenerCounter implements ActionListener<InferenceResults> {
+    public void testCallingRunAfterParentTaskCancellation() throws Exception {
+        DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class);
+        PyTorchResultProcessor resultProcessor = mock(PyTorchResultProcessor.class);
+        when(processContext.getResultProcessor()).thenReturn(resultProcessor);
+        AtomicInteger timeoutCount = new AtomicInteger();
+        when(processContext.getTimeoutCount()).thenReturn(timeoutCount);
+        TaskManager taskManager = new TaskManager(Settings.EMPTY, tp, Set.of());
+        TestListenerCounter listener = new TestListenerCounter();
+        CancellableTask cancellableTask = (CancellableTask) taskManager.register("test_task", "testAction", new TaskAwareRequest() {
+            @Override
+            public void setParentTask(TaskId taskId) {}
+
+            @Override
+            public TaskId getParentTask() {
+                return TaskId.EMPTY_TASK_ID;
+            }
+
+            @Override
+            public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+                return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers);
+            }
+        });
+        InferencePyTorchAction action = new InferencePyTorchAction(
+            "test-model",
+            1,
+            TimeValue.MAX_VALUE,
+            processContext,
+            new PassThroughConfig(null, null, null),
+            Map.of(),
+            tp,
+            cancellableTask,
+            listener
+        );
+        action.init();
+        taskManager.cancel(cancellableTask, "test", () -> {});
+
+        action.doRun();
+        assertThat(listener.failureCounts, equalTo(1));
+        assertThat(listener.responseCounts, equalTo(0));
+        verify(resultProcessor, never()).registerRequest(anyString(), any());
+    }
+
+    static class TestListenerCounter implements ActionListener<InferenceResults> {
         private int responseCounts;
         private int failureCounts;
```
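
For readers who want to exercise the cancellation plumbing outside the test suite, here is a self-contained harness that mirrors the new test's setup (the pool name and cancel reason are arbitrary; TestThreadPool comes from the Elasticsearch test framework):

```java
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;

public class CancelHarnessSketch {
    public static void main(String[] args) {
        ThreadPool threadPool = new TestThreadPool("cancel-harness");
        try {
            TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of());
            // Register a task the same way the new test does: createTask returns a CancellableTask.
            CancellableTask task = (CancellableTask) taskManager.register("test_task", "testAction", new TaskAwareRequest() {
                @Override
                public void setParentTask(TaskId taskId) {}

                @Override
                public TaskId getParentTask() {
                    return TaskId.EMPTY_TASK_ID;
                }

                @Override
                public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                    return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers);
                }
            });
            taskManager.cancel(task, "caller went away", () -> {});
            try {
                task.ensureNotCancelled(); // the call the action's isCancelled() relies on
            } catch (TaskCancelledException e) {
                System.out.println("cancelled: " + e.getMessage()); // expected after taskManager.cancel
            }
        } finally {
            ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
        }
    }
}
```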
