1313import org .elasticsearch .action .ingest .DeletePipelineRequest ;
1414import org .elasticsearch .action .ingest .PutPipelineAction ;
1515import org .elasticsearch .action .ingest .PutPipelineRequest ;
16+ import org .elasticsearch .action .support .WriteRequest ;
1617import org .elasticsearch .cluster .ClusterState ;
1718import org .elasticsearch .common .bytes .BytesArray ;
1819import org .elasticsearch .common .xcontent .XContentType ;
20+ import org .elasticsearch .tasks .TaskInfo ;
1921import org .elasticsearch .xpack .core .ml .MlMetadata ;
2022import org .elasticsearch .xpack .core .ml .action .PutDataFrameAnalyticsAction ;
23+ import org .elasticsearch .xpack .core .ml .action .PutTrainedModelAction ;
2124import org .elasticsearch .xpack .core .ml .action .StartDataFrameAnalyticsAction ;
25+ import org .elasticsearch .xpack .core .ml .action .StartTrainedModelDeploymentAction ;
2226import org .elasticsearch .xpack .core .ml .datafeed .DatafeedConfig ;
2327import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsConfig ;
2428import org .elasticsearch .xpack .core .ml .dataframe .analyses .BoostedTreeParams ;
2529import org .elasticsearch .xpack .core .ml .dataframe .analyses .Classification ;
30+ import org .elasticsearch .xpack .core .ml .inference .TrainedModelConfig ;
31+ import org .elasticsearch .xpack .core .ml .inference .TrainedModelInput ;
32+ import org .elasticsearch .xpack .core .ml .inference .TrainedModelType ;
33+ import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ClassificationConfig ;
34+ import org .elasticsearch .xpack .core .ml .inference .trainedmodel .IndexLocation ;
2635import org .elasticsearch .xpack .core .ml .job .config .Job ;
2736import org .elasticsearch .xpack .core .ml .job .config .JobState ;
2837import org .elasticsearch .xpack .core .ml .job .process .autodetect .state .DataCounts ;
2938import org .junit .After ;
3039
40+ import java .util .Arrays ;
3141import java .util .Collections ;
3242import java .util .HashSet ;
43+ import java .util .List ;
3344import java .util .Set ;
3445import java .util .concurrent .TimeUnit ;
46+ import java .util .stream .Collectors ;
3547
3648import static org .elasticsearch .xpack .ml .inference .ingest .InferenceProcessor .Factory .countNumberInferenceProcessors ;
3749import static org .elasticsearch .xpack .ml .integration .ClassificationIT .KEYWORD_FIELD ;
3850import static org .elasticsearch .xpack .ml .integration .MlNativeDataFrameAnalyticsIntegTestCase .buildAnalytics ;
51+ import static org .elasticsearch .xpack .ml .integration .PyTorchModelIT .BASE_64_ENCODED_MODEL ;
52+ import static org .elasticsearch .xpack .ml .integration .PyTorchModelIT .RAW_MODEL_SIZE ;
3953import static org .elasticsearch .xpack .ml .support .BaseMlIntegTestCase .createDatafeed ;
4054import static org .elasticsearch .xpack .ml .support .BaseMlIntegTestCase .createScheduledJob ;
4155import static org .elasticsearch .xpack .ml .support .BaseMlIntegTestCase .getDataCounts ;
4256import static org .elasticsearch .xpack .ml .support .BaseMlIntegTestCase .indexDocs ;
4357import static org .hamcrest .Matchers .containsString ;
44- import static org .hamcrest .Matchers .emptyArray ;
58+ import static org .hamcrest .Matchers .empty ;
4559import static org .hamcrest .Matchers .equalTo ;
4660import static org .hamcrest .Matchers .greaterThan ;
4761import static org .hamcrest .Matchers .is ;
@@ -51,6 +65,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
5165 private final Set <String > createdPipelines = new HashSet <>();
5266 private final Set <String > jobIds = new HashSet <>();
5367 private final Set <String > datafeedIds = new HashSet <>();
68+ private static final String TRAINED_MODEL_ID = "trained-model-to-reset" ;
5469
5570 void cleanupDatafeed (String datafeedId ) {
5671 try {
@@ -122,7 +137,10 @@ public void testMLFeatureReset() throws Exception {
122137 ResetFeatureStateAction .INSTANCE ,
123138 new ResetFeatureStateRequest ()
124139 ).actionGet ();
125- assertBusy (() -> assertThat (client ().admin ().indices ().prepareGetIndex ().addIndices (".ml*" ).get ().indices (), emptyArray ()));
140+ assertBusy (() -> {
141+ List <String > indices = Arrays .asList (client ().admin ().indices ().prepareGetIndex ().addIndices (".ml*" ).get ().indices ());
142+ assertThat (indices .toString (), indices , is (empty ()));
143+ });
126144 assertThat (isResetMode (), is (false ));
127145 // If we have succeeded, clear the jobs and datafeeds so that the delete API doesn't recreate the notifications index
128146 jobIds .clear ();
@@ -147,6 +165,94 @@ public void testMLFeatureResetFailureDueToPipelines() throws Exception {
147165 assertThat (isResetMode (), is (false ));
148166 }
149167
168+ public void testMLFeatureResetWithModelDeployment () throws Exception {
169+ createModelDeployment ();
170+ client ().execute (
171+ ResetFeatureStateAction .INSTANCE ,
172+ new ResetFeatureStateRequest ()
173+ ).actionGet ();
174+ assertBusy (() -> {
175+ List <String > indices = Arrays .asList (client ().admin ().indices ().prepareGetIndex ().addIndices (".ml*" ).get ().indices ());
176+ assertThat (indices .toString (), indices , is (empty ()));
177+ });
178+ assertThat (isResetMode (), is (false ));
179+ List <String > tasksNames = client ().admin ()
180+ .cluster ()
181+ .prepareListTasks ()
182+ .setActions ("xpack/ml/*" )
183+ .get ()
184+ .getTasks ()
185+ .stream ()
186+ .map (TaskInfo ::getAction )
187+ .collect (Collectors .toList ());
188+ assertThat (tasksNames , is (empty ()));
189+ }
190+
191+ void createModelDeployment () {
192+ String indexname = "model_store" ;
193+ client ().admin ().indices ().prepareCreate (indexname ).setMapping (
194+ " {\" properties\" : {\n " +
195+ " \" doc_type\" : { \" type\" : \" keyword\" },\n " +
196+ " \" model_id\" : { \" type\" : \" keyword\" },\n " +
197+ " \" definition_length\" : { \" type\" : \" long\" },\n " +
198+ " \" total_definition_length\" : { \" type\" : \" long\" },\n " +
199+ " \" compression_version\" : { \" type\" : \" long\" },\n " +
200+ " \" definition\" : { \" type\" : \" binary\" },\n " +
201+ " \" eos\" : { \" type\" : \" boolean\" },\n " +
202+ " \" task_type\" : { \" type\" : \" keyword\" },\n " +
203+ " \" vocab\" : { \" type\" : \" keyword\" },\n " +
204+ " \" with_special_tokens\" : { \" type\" : \" boolean\" },\n " +
205+ " \" do_lower_case\" : { \" type\" : \" boolean\" }\n " +
206+ " }\n " +
207+ " }}"
208+ ).get ();
209+ client ().prepareIndex (indexname )
210+ .setId (TRAINED_MODEL_ID + "_task_config" )
211+ .setSource (
212+ "{ " +
213+ "\" task_type\" : \" bert_pass_through\" ,\n " +
214+ "\" with_special_tokens\" : false," +
215+ "\" vocab\" : [\" these\" , \" are\" , \" my\" , \" words\" ]\n " +
216+ "}" ,
217+ XContentType .JSON
218+ ).setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
219+ .get ();
220+ client ().prepareIndex (indexname )
221+ .setId ("trained_model_definition_doc-" + TRAINED_MODEL_ID + "-0" )
222+ .setSource (
223+ "{ " +
224+ "\" doc_type\" : \" trained_model_definition_doc\" ," +
225+ "\" model_id\" : \" " + TRAINED_MODEL_ID +"\" ," +
226+ "\" doc_num\" : 0," +
227+ "\" definition_length\" :" + RAW_MODEL_SIZE + "," +
228+ "\" total_definition_length\" :" + RAW_MODEL_SIZE + "," +
229+ "\" compression_version\" : 1," +
230+ "\" definition\" : \" " + BASE_64_ENCODED_MODEL + "\" ," +
231+ "\" eos\" : true" +
232+ "}" ,
233+ XContentType .JSON
234+ ).setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
235+ .get ();
236+ client ()
237+ .execute (
238+ PutTrainedModelAction .INSTANCE ,
239+ new PutTrainedModelAction .Request (
240+ TrainedModelConfig .builder ()
241+ .setModelType (TrainedModelType .PYTORCH )
242+ .setInferenceConfig (new ClassificationConfig (1 ))
243+ .setInput (new TrainedModelInput (Arrays .asList ("text_field" )))
244+ .setLocation (new IndexLocation (TRAINED_MODEL_ID , indexname ))
245+ .setModelId (TRAINED_MODEL_ID )
246+ .build ()
247+ )
248+ )
249+ .actionGet ();
250+ client ().execute (
251+ StartTrainedModelDeploymentAction .INSTANCE ,
252+ new StartTrainedModelDeploymentAction .Request (TRAINED_MODEL_ID )
253+ ).actionGet ();
254+ }
255+
150256 private boolean isResetMode () {
151257 ClusterState state = client ().admin ().cluster ().prepareState ().get ().getState ();
152258 return MlMetadata .getMlMetadata (state ).isResetMode ();
0 commit comments