Skip to content

Commit f1d8593

Browse files
authored
[ML] integrating feature reset for trained model deployments (#76126)
this integrates removing all model deployments in the ML feature reset action.
1 parent 7e7b5c9 commit f1d8593

File tree

12 files changed

+313
-20
lines changed

12 files changed

+313
-20
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,6 @@ public static PersistentTasksCustomMetadata.PersistentTask<?> getSnapshotUpgrade
126126
return tasks == null ? null : tasks.getTask(snapshotUpgradeTaskId(jobId, snapshotId));
127127
}
128128

129-
@Nullable
130-
public static PersistentTasksCustomMetadata.PersistentTask<?> getTrainedModelDeploymentTask(
131-
String modelId, @Nullable PersistentTasksCustomMetadata tasks) {
132-
return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId));
133-
}
134-
135129
/**
136130
* Note that the return value of this method does NOT take node relocations into account.
137131
* Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.elasticsearch.common.xcontent.ToXContentObject;
2626
import org.elasticsearch.common.xcontent.XContentBuilder;
2727
import org.elasticsearch.tasks.Task;
28-
import org.elasticsearch.xpack.core.ml.MlTasks;
2928
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
3029
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
3130
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -35,6 +34,8 @@
3534
import java.util.Objects;
3635
import java.util.concurrent.TimeUnit;
3736

37+
import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelDeploymentTaskId;
38+
3839
public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
3940

4041
public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
@@ -237,7 +238,7 @@ static boolean match(Task task, String expectedId) {
237238
if (Strings.isAllOrWildcard(expectedId)) {
238239
return true;
239240
}
240-
String expectedDescription = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + expectedId;
241+
String expectedDescription = trainedModelDeploymentTaskId(expectedId);
241242
return expectedDescription.equals(task.getDescription());
242243
}
243244
return false;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/annotations/AnnotationIndex.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
*/
77
package org.elasticsearch.xpack.core.ml.annotations;
88

9+
import org.apache.logging.log4j.LogManager;
10+
import org.apache.logging.log4j.Logger;
11+
import org.apache.logging.log4j.message.ParameterizedMessage;
912
import org.elasticsearch.ResourceAlreadyExistsException;
1013
import org.elasticsearch.Version;
1114
import org.elasticsearch.action.ActionListener;
@@ -34,6 +37,8 @@
3437

3538
public class AnnotationIndex {
3639

40+
private static final Logger logger = LogManager.getLogger(AnnotationIndex.class);
41+
3742
public static final String READ_ALIAS_NAME = ".ml-annotations-read";
3843
public static final String WRITE_ALIAS_NAME = ".ml-annotations-write";
3944
// Exposed for testing, but always use the aliases in non-test code
@@ -100,6 +105,14 @@ public static void createAnnotationsIndexIfNecessary(Client client, ClusterState
100105

101106
// Create the annotations index if it doesn't exist already.
102107
if (mlLookup.containsKey(INDEX_NAME) == false) {
108+
logger.debug(
109+
() -> new ParameterizedMessage(
110+
"Creating [{}] because [{}] exists; trace {}",
111+
INDEX_NAME,
112+
mlLookup.firstKey(),
113+
org.elasticsearch.ExceptionsHelper.formatStackTrace(Thread.currentThread().getStackTrace())
114+
)
115+
);
103116

104117
CreateIndexRequest createIndexRequest =
105118
new CreateIndexRequest(INDEX_NAME)

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
9090
import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason;
9191
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
92+
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
9293
import org.elasticsearch.xpack.transform.Transform;
9394

9495
import java.io.IOException;
@@ -280,6 +281,20 @@ protected void ensureClusterStateConsistency() throws IOException {
280281
if (cluster() != null && cluster().size() > 0) {
281282
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(ClusterModule.getNamedWriteables());
282283
entries.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
284+
entries.add(
285+
new NamedWriteableRegistry.Entry(
286+
Metadata.Custom.class,
287+
TrainedModelAllocationMetadata.NAME,
288+
TrainedModelAllocationMetadata::new
289+
)
290+
);
291+
entries.add(
292+
new NamedWriteableRegistry.Entry(
293+
NamedDiff.class,
294+
TrainedModelAllocationMetadata.NAME,
295+
TrainedModelAllocationMetadata::readDiffFrom
296+
)
297+
);
283298
entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new));
284299
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom));
285300
entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new));

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public void unsetLogging() throws IOException {
8383

8484
private static final String MODEL_INDEX = "model_store";
8585
private static final String MODEL_ID ="simple_model_to_evaluate";
86-
private static final String BASE_64_ENCODED_MODEL =
86+
static final String BASE_64_ENCODED_MODEL =
8787
"UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" +
8888
"TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" +
8989
"AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" +
@@ -106,7 +106,7 @@ public void unsetLogging() throws IOException {
106106
"EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe" +
107107
"Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE" +
108108
"AAJIEAAAAAA==";
109-
private static final int RAW_MODEL_SIZE; // size of the model before base64 encoding
109+
static final int RAW_MODEL_SIZE; // size of the model before base64 encoding
110110
static {
111111
RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
112112
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,49 @@
1313
import org.elasticsearch.action.ingest.DeletePipelineRequest;
1414
import org.elasticsearch.action.ingest.PutPipelineAction;
1515
import org.elasticsearch.action.ingest.PutPipelineRequest;
16+
import org.elasticsearch.action.support.WriteRequest;
1617
import org.elasticsearch.cluster.ClusterState;
1718
import org.elasticsearch.common.bytes.BytesArray;
1819
import org.elasticsearch.common.xcontent.XContentType;
20+
import org.elasticsearch.tasks.TaskInfo;
1921
import org.elasticsearch.xpack.core.ml.MlMetadata;
2022
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
23+
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
2124
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
25+
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
2226
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
2327
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
2428
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
2529
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
30+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
31+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
32+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
33+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
34+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
2635
import org.elasticsearch.xpack.core.ml.job.config.Job;
2736
import org.elasticsearch.xpack.core.ml.job.config.JobState;
2837
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
2938
import org.junit.After;
3039

40+
import java.util.Arrays;
3141
import java.util.Collections;
3242
import java.util.HashSet;
43+
import java.util.List;
3344
import java.util.Set;
3445
import java.util.concurrent.TimeUnit;
46+
import java.util.stream.Collectors;
3547

3648
import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
3749
import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
3850
import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
51+
import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
52+
import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.RAW_MODEL_SIZE;
3953
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createDatafeed;
4054
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
4155
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
4256
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
4357
import static org.hamcrest.Matchers.containsString;
44-
import static org.hamcrest.Matchers.emptyArray;
58+
import static org.hamcrest.Matchers.empty;
4559
import static org.hamcrest.Matchers.equalTo;
4660
import static org.hamcrest.Matchers.greaterThan;
4761
import static org.hamcrest.Matchers.is;
@@ -51,6 +65,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
5165
private final Set<String> createdPipelines = new HashSet<>();
5266
private final Set<String> jobIds = new HashSet<>();
5367
private final Set<String> datafeedIds = new HashSet<>();
68+
private static final String TRAINED_MODEL_ID = "trained-model-to-reset";
5469

5570
void cleanupDatafeed(String datafeedId) {
5671
try {
@@ -122,7 +137,10 @@ public void testMLFeatureReset() throws Exception {
122137
ResetFeatureStateAction.INSTANCE,
123138
new ResetFeatureStateRequest()
124139
).actionGet();
125-
assertBusy(() -> assertThat(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices(), emptyArray()));
140+
assertBusy(() -> {
141+
List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
142+
assertThat(indices.toString(), indices, is(empty()));
143+
});
126144
assertThat(isResetMode(), is(false));
127145
// If we have succeeded, clear the jobs and datafeeds so that the delete API doesn't recreate the notifications index
128146
jobIds.clear();
@@ -147,6 +165,94 @@ public void testMLFeatureResetFailureDueToPipelines() throws Exception {
147165
assertThat(isResetMode(), is(false));
148166
}
149167

168+
public void testMLFeatureResetWithModelDeployment() throws Exception {
169+
createModelDeployment();
170+
client().execute(
171+
ResetFeatureStateAction.INSTANCE,
172+
new ResetFeatureStateRequest()
173+
).actionGet();
174+
assertBusy(() -> {
175+
List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
176+
assertThat(indices.toString(), indices, is(empty()));
177+
});
178+
assertThat(isResetMode(), is(false));
179+
List<String> tasksNames = client().admin()
180+
.cluster()
181+
.prepareListTasks()
182+
.setActions("xpack/ml/*")
183+
.get()
184+
.getTasks()
185+
.stream()
186+
.map(TaskInfo::getAction)
187+
.collect(Collectors.toList());
188+
assertThat(tasksNames, is(empty()));
189+
}
190+
191+
void createModelDeployment() {
192+
String indexname = "model_store";
193+
client().admin().indices().prepareCreate(indexname).setMapping(
194+
" {\"properties\": {\n" +
195+
" \"doc_type\": { \"type\": \"keyword\" },\n" +
196+
" \"model_id\": { \"type\": \"keyword\" },\n" +
197+
" \"definition_length\": { \"type\": \"long\" },\n" +
198+
" \"total_definition_length\": { \"type\": \"long\" },\n" +
199+
" \"compression_version\": { \"type\": \"long\" },\n" +
200+
" \"definition\": { \"type\": \"binary\" },\n" +
201+
" \"eos\": { \"type\": \"boolean\" },\n" +
202+
" \"task_type\": { \"type\": \"keyword\" },\n" +
203+
" \"vocab\": { \"type\": \"keyword\" },\n" +
204+
" \"with_special_tokens\": { \"type\": \"boolean\" },\n" +
205+
" \"do_lower_case\": { \"type\": \"boolean\" }\n" +
206+
" }\n" +
207+
" }}"
208+
).get();
209+
client().prepareIndex(indexname)
210+
.setId(TRAINED_MODEL_ID + "_task_config")
211+
.setSource(
212+
"{ " +
213+
"\"task_type\": \"bert_pass_through\",\n" +
214+
"\"with_special_tokens\": false," +
215+
"\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
216+
"}",
217+
XContentType.JSON
218+
).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
219+
.get();
220+
client().prepareIndex(indexname)
221+
.setId("trained_model_definition_doc-" + TRAINED_MODEL_ID + "-0")
222+
.setSource(
223+
"{ " +
224+
"\"doc_type\": \"trained_model_definition_doc\"," +
225+
"\"model_id\": \"" + TRAINED_MODEL_ID +"\"," +
226+
"\"doc_num\": 0," +
227+
"\"definition_length\":" + RAW_MODEL_SIZE + "," +
228+
"\"total_definition_length\":" + RAW_MODEL_SIZE + "," +
229+
"\"compression_version\": 1," +
230+
"\"definition\": \"" + BASE_64_ENCODED_MODEL + "\"," +
231+
"\"eos\": true" +
232+
"}",
233+
XContentType.JSON
234+
).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
235+
.get();
236+
client()
237+
.execute(
238+
PutTrainedModelAction.INSTANCE,
239+
new PutTrainedModelAction.Request(
240+
TrainedModelConfig.builder()
241+
.setModelType(TrainedModelType.PYTORCH)
242+
.setInferenceConfig(new ClassificationConfig(1))
243+
.setInput(new TrainedModelInput(Arrays.asList("text_field")))
244+
.setLocation(new IndexLocation(TRAINED_MODEL_ID, indexname))
245+
.setModelId(TRAINED_MODEL_ID)
246+
.build()
247+
)
248+
)
249+
.actionGet();
250+
client().execute(
251+
StartTrainedModelDeploymentAction.INSTANCE,
252+
new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID)
253+
).actionGet();
254+
}
255+
150256
private boolean isResetMode() {
151257
ClusterState state = client().admin().cluster().prepareState().get().getState();
152258
return MlMetadata.getMlMetadata(state).isResetMode();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
567567
private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
568568
private final SetOnce<MlAutoscalingDeciderService> mlAutoscalingDeciderService = new SetOnce<>();
569569
private final SetOnce<DeploymentManager> deploymentManager = new SetOnce<>();
570+
private final SetOnce<TrainedModelAllocationClusterService> trainedModelAllocationClusterServiceSetOnce = new SetOnce<>();
570571

571572
public MachineLearning(Settings settings, Path configPath) {
572573
this.settings = settings;
@@ -870,11 +871,11 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
870871
clusterService,
871872
threadPool
872873
);
873-
final TrainedModelAllocationClusterService trainedModelAllocationClusterService = new TrainedModelAllocationClusterService(
874+
trainedModelAllocationClusterServiceSetOnce.set(new TrainedModelAllocationClusterService(
874875
settings,
875876
clusterService,
876877
new NodeLoadDetector(memoryTracker)
877-
);
878+
));
878879

879880
mlAutoscalingDeciderService.set(new MlAutoscalingDeciderService(memoryTracker, settings, clusterService));
880881

@@ -905,7 +906,7 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
905906
modelLoadingService,
906907
trainedModelProvider,
907908
trainedModelAllocationService,
908-
trainedModelAllocationClusterService,
909+
trainedModelAllocationClusterServiceSetOnce.get(),
909910
deploymentManager.get()
910911
);
911912
}
@@ -1375,7 +1376,10 @@ public void cleanUpFeature(
13751376

13761377
ActionListener<ResetFeatureStateResponse.ResetFeatureStateStatus> unsetResetModeListener = ActionListener.wrap(
13771378
success -> client.execute(SetResetModeAction.INSTANCE, SetResetModeActionRequest.disabled(true), ActionListener.wrap(
1378-
resetSuccess -> finalListener.onResponse(success),
1379+
resetSuccess -> {
1380+
finalListener.onResponse(success);
1381+
logger.info("Finished machine learning feature reset");
1382+
},
13791383
resetFailure -> {
13801384
logger.error("failed to disable reset mode after state otherwise successful machine learning reset", resetFailure);
13811385
finalListener.onFailure(
@@ -1434,6 +1438,7 @@ public void cleanUpFeature(
14341438
client.admin()
14351439
.cluster()
14361440
.prepareListTasks()
1441+
// This waits for all xpack actions including: allocations, anomaly detections, analytics
14371442
.setActions("xpack/ml/*")
14381443
.setWaitForCompletion(true)
14391444
.execute(ActionListener.wrap(
@@ -1504,7 +1509,7 @@ public void cleanUpFeature(
15041509
}, unsetResetModeListener::onFailure);
15051510

15061511
// Stop data feeds
1507-
ActionListener<AcknowledgedResponse> pipelineValidation = ActionListener.wrap(
1512+
ActionListener<AcknowledgedResponse> stopDeploymentsListener = ActionListener.wrap(
15081513
acknowledgedResponse -> {
15091514
StopDatafeedAction.Request stopDatafeedsReq = new StopDatafeedAction.Request("_all")
15101515
.setAllowNoMatch(true);
@@ -1519,6 +1524,18 @@ public void cleanUpFeature(
15191524
unsetResetModeListener::onFailure
15201525
);
15211526

1527+
// Stop all model deployments
1528+
ActionListener<AcknowledgedResponse> pipelineValidation = ActionListener.wrap(
1529+
acknowledgedResponse -> {
1530+
if (trainedModelAllocationClusterServiceSetOnce.get() == null) {
1531+
stopDeploymentsListener.onResponse(AcknowledgedResponse.TRUE);
1532+
return;
1533+
}
1534+
trainedModelAllocationClusterServiceSetOnce.get().removeAllModelAllocations(stopDeploymentsListener);
1535+
},
1536+
unsetResetModeListener::onFailure
1537+
);
1538+
15221539
// validate no pipelines are using machine learning models
15231540
ActionListener<AcknowledgedResponse> afterResetModeSet = ActionListener.wrap(
15241541
acknowledgedResponse -> {

0 commit comments

Comments
 (0)