Skip to content

Commit 71ab4c4

Browse files
authored
[ML] make trained model rest APIs cancellable (#88009)
This change makes all the trained model APIs cancellable, and addresses the handful of APIs that rely on our abstract resource structure. closes: #87931
1 parent f153c2a commit 71ab4c4

File tree

34 files changed

+311
-138
lines changed

34 files changed

+311
-138
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/AbstractGetResourcesRequest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
import org.elasticsearch.action.ActionRequestValidationException;
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.tasks.CancellableTask;
14+
import org.elasticsearch.tasks.Task;
15+
import org.elasticsearch.tasks.TaskId;
1316
import org.elasticsearch.xpack.core.action.util.PageParams;
1417

1518
import java.io.IOException;
19+
import java.util.Map;
1620
import java.util.Objects;
1721

1822
public abstract class AbstractGetResourcesRequest extends ActionRequest {
@@ -93,5 +97,12 @@ public boolean equals(Object obj) {
9397
&& allowNoResources == other.allowNoResources;
9498
}
9599

100+
@Override
101+
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
102+
return new CancellableTask(id, type, action, getCancelableTaskDescription(), parentTaskId, headers);
103+
}
104+
105+
public abstract String getCancelableTaskDescription();
106+
96107
public abstract String getResourceIdField();
97108
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/AbstractTransportGetResourcesAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.search.SearchHit;
2727
import org.elasticsearch.search.builder.SearchSourceBuilder;
2828
import org.elasticsearch.search.sort.SortBuilders;
29+
import org.elasticsearch.tasks.TaskId;
2930
import org.elasticsearch.transport.TransportService;
3031
import org.elasticsearch.xcontent.NamedXContentRegistry;
3132
import org.elasticsearch.xcontent.ParseField;
@@ -73,7 +74,7 @@ protected AbstractTransportGetResourcesAction(
7374
this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
7475
}
7576

76-
protected void searchResources(AbstractGetResourcesRequest request, ActionListener<QueryPage<Resource>> listener) {
77+
protected void searchResources(AbstractGetResourcesRequest request, TaskId parentTaskId, ActionListener<QueryPage<Resource>> listener) {
7778
String[] tokens = Strings.tokenizeToStringArray(request.getResourceId(), ",");
7879
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().sort(
7980
SortBuilders.fieldSort(request.getResourceIdField())
@@ -96,6 +97,7 @@ protected void searchResources(AbstractGetResourcesRequest request, ActionListen
9697
indicesOptions
9798
)
9899
).source(customSearchOptions(sourceBuilder));
100+
searchRequest.setParentTask(parentTaskId);
99101

100102
executeAsyncWithOrigin(
101103
client.threadPool().getThreadContext(),

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import java.io.IOException;
1818

19+
import static org.elasticsearch.core.Strings.format;
20+
1921
public class GetDataFrameAnalyticsAction extends ActionType<GetDataFrameAnalyticsAction.Response> {
2022

2123
public static final GetDataFrameAnalyticsAction INSTANCE = new GetDataFrameAnalyticsAction();
@@ -46,6 +48,11 @@ public Request(StreamInput in) throws IOException {
4648
public String getResourceIdField() {
4749
return DataFrameAnalyticsConfig.ID.getPreferredName();
4850
}
51+
52+
@Override
53+
public String getCancelableTaskDescription() {
54+
return format("get_data_frame_analytics[%s]", getResourceId());
55+
}
4956
}
5057

5158
public static class Response extends AbstractGetResourcesResponse<DataFrameAnalyticsConfig> {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetFiltersAction.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import java.io.IOException;
1919

20+
import static org.elasticsearch.core.Strings.format;
21+
2022
public class GetFiltersAction extends ActionType<GetFiltersAction.Response> {
2123

2224
public static final GetFiltersAction INSTANCE = new GetFiltersAction();
@@ -41,6 +43,11 @@ public Request(StreamInput in) throws IOException {
4143
super(in);
4244
}
4345

46+
@Override
47+
public String getCancelableTaskDescription() {
48+
return format("get_filters[%s]", getResourceId());
49+
}
50+
4451
@Override
4552
public String getResourceIdField() {
4653
return MlFilter.ID.getPreferredName();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.util.Objects;
2626
import java.util.Set;
2727

28+
import static org.elasticsearch.core.Strings.format;
29+
2830
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
2931

3032
public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();
@@ -118,7 +120,6 @@ public int hashCode() {
118120
public static class Request extends AbstractGetResourcesRequest {
119121

120122
public static final ParseField INCLUDE = new ParseField("include");
121-
public static final String DEFINITION = "definition";
122123
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
123124
public static final ParseField TAGS = new ParseField("tags");
124125

@@ -178,6 +179,11 @@ public boolean equals(Object obj) {
178179
Request other = (Request) obj;
179180
return super.equals(obj) && this.includes.equals(other.includes) && Objects.equals(tags, other.tags);
180181
}
182+
183+
@Override
184+
public String getCancelableTaskDescription() {
185+
return format("get_trained_models[%s]", getResourceId());
186+
}
181187
}
182188

183189
public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import java.util.Set;
3636

3737
import static org.elasticsearch.core.RestApiVersion.onOrAfter;
38+
import static org.elasticsearch.core.Strings.format;
3839

3940
public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {
4041

@@ -68,6 +69,11 @@ public Request(StreamInput in) throws IOException {
6869
super(in);
6970
}
7071

72+
@Override
73+
public String getCancelableTaskDescription() {
74+
return format("get_trained_model_stats[%s]", getResourceId());
75+
}
76+
7177
@Override
7278
public String getResourceIdField() {
7379
return TrainedModelConfig.MODEL_ID.getPreferredName();

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import org.elasticsearch.common.io.stream.StreamInput;
1515
import org.elasticsearch.common.io.stream.StreamOutput;
1616
import org.elasticsearch.core.TimeValue;
17+
import org.elasticsearch.tasks.CancellableTask;
18+
import org.elasticsearch.tasks.Task;
19+
import org.elasticsearch.tasks.TaskId;
1720
import org.elasticsearch.xcontent.ObjectParser;
1821
import org.elasticsearch.xcontent.ParseField;
1922
import org.elasticsearch.xcontent.ToXContentObject;
@@ -31,6 +34,8 @@
3134
import java.util.Objects;
3235
import java.util.stream.Collectors;
3336

37+
import static org.elasticsearch.core.Strings.format;
38+
3439
public class InferModelAction extends ActionType<InferModelAction.Response> {
3540
public static final String NAME = "cluster:internal/xpack/ml/inference/infer";
3641
public static final String EXTERNAL_NAME = "cluster:monitor/xpack/ml/inference/infer";
@@ -176,6 +181,11 @@ public boolean equals(Object o) {
176181
&& Objects.equals(objectsToInfer, that.objectsToInfer);
177182
}
178183

184+
@Override
185+
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
186+
return new CancellableTask(id, type, action, format("infer_trained_model[%s]", modelId), parentTaskId, headers);
187+
}
188+
179189
@Override
180190
public int hashCode() {
181191
return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed, timeout);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import org.elasticsearch.common.io.stream.Writeable;
1818
import org.elasticsearch.core.Nullable;
1919
import org.elasticsearch.core.TimeValue;
20+
import org.elasticsearch.tasks.CancellableTask;
2021
import org.elasticsearch.tasks.Task;
22+
import org.elasticsearch.tasks.TaskId;
2123
import org.elasticsearch.xcontent.ObjectParser;
2224
import org.elasticsearch.xcontent.ParseField;
2325
import org.elasticsearch.xcontent.ToXContentObject;
@@ -36,6 +38,7 @@
3638
import java.util.Optional;
3739

3840
import static org.elasticsearch.action.ValidateActions.addValidationError;
41+
import static org.elasticsearch.core.Strings.format;
3942

4043
public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedModelDeploymentAction.Response> {
4144

@@ -192,6 +195,11 @@ public int hashCode() {
192195
return Objects.hash(deploymentId, update, docs, inferenceTimeout);
193196
}
194197

198+
@Override
199+
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
200+
return new CancellableTask(id, type, action, format("infer_trained_model_deployment[%s]", deploymentId), parentTaskId, headers);
201+
}
202+
195203
public static class Builder {
196204

197205
private String deploymentId;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/action/GetTransformAction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.Objects;
3333

3434
import static org.elasticsearch.action.ValidateActions.addValidationError;
35+
import static org.elasticsearch.core.Strings.format;
3536

3637
public class GetTransformAction extends ActionType<GetTransformAction.Response> {
3738

@@ -76,6 +77,11 @@ public ActionRequestValidationException validate() {
7677
return exception;
7778
}
7879

80+
@Override
81+
public String getCancelableTaskDescription() {
82+
return format("get_transforms[%s]", getResourceId());
83+
}
84+
7985
@Override
8086
public String getResourceIdField() {
8187
return TransformField.ID.getPreferredName();

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,15 @@ public void testStoreModelViaChunkedPersister() throws IOException {
110110
PageParams.defaultParams(),
111111
Collections.emptySet(),
112112
ModelAliasMetadata.EMPTY,
113+
null,
113114
getIdsFuture
114115
);
115116
Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
116117
assertThat(ids.v1(), equalTo(1L));
117118
String inferenceModelId = ids.v2().keySet().iterator().next();
118119

119120
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
120-
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture);
121+
trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), null, getTrainedModelFuture);
121122

122123
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
123124
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
@@ -128,7 +129,7 @@ public void testStoreModelViaChunkedPersister() throws IOException {
128129
assertThat(storedConfig.getMetadata(), hasKey("hyperparameters"));
129130

130131
PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
131-
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
132+
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), null, getTrainedMetadataFuture);
132133

133134
TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
134135
assertThat(storedMetadata.getModelId(), startsWith(modelId));

0 commit comments

Comments
 (0)