55 */
66package org .elasticsearch .xpack .ml .dataframe ;
77
8- import org .elasticsearch .Version ;
98import org .elasticsearch .action .ActionListener ;
109import org .elasticsearch .action .index .IndexAction ;
1110import org .elasticsearch .action .index .IndexRequest ;
1211import org .elasticsearch .action .index .IndexResponse ;
1312import org .elasticsearch .action .search .SearchAction ;
1413import org .elasticsearch .action .search .SearchResponse ;
1514import org .elasticsearch .client .Client ;
16- import org .elasticsearch .cluster .service .ClusterService ;
1715import org .elasticsearch .common .settings .Settings ;
1816import org .elasticsearch .common .util .concurrent .ThreadContext ;
1917import org .elasticsearch .search .SearchHit ;
2220import org .elasticsearch .threadpool .ThreadPool ;
2321import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsAction ;
2422import org .elasticsearch .xpack .core .ml .action .GetDataFrameAnalyticsStatsActionResponseTests ;
25- import org .elasticsearch .xpack .core .ml .action .StartDataFrameAnalyticsAction .TaskParams ;
2623import org .elasticsearch .xpack .core .ml .utils .PhaseProgress ;
2724import org .elasticsearch .xpack .ml .dataframe .DataFrameAnalyticsTask .StartingState ;
28- import org .elasticsearch .xpack .ml .notifications .DataFrameAnalyticsAuditor ;
2925import org .mockito .ArgumentCaptor ;
3026import org .mockito .InOrder ;
3127import org .mockito .stubbing .Answer ;
@@ -116,13 +112,13 @@ public void testDetermineStartingState_GivenEmptyProgress() {
116112 assertThat (startingState , equalTo (StartingState .FINISHED ));
117113 }
118114
119- private void testMarkAsCompleted (SearchHits searchHits , String expectedIndexOrAlias ) {
115+ private void testPersistProgress (SearchHits searchHits , String expectedIndexOrAlias ) {
120116 Client client = mock (Client .class );
121117 ThreadPool threadPool = mock (ThreadPool .class );
122118 when (threadPool .getThreadContext ()).thenReturn (new ThreadContext (Settings .EMPTY ));
123119 when (client .threadPool ()).thenReturn (threadPool );
124120
125- GetDataFrameAnalyticsStatsAction .Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests .randomResponse ();
121+ GetDataFrameAnalyticsStatsAction .Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests .randomResponse (1 );
126122 doAnswer (withResponse (getStatsResponse )).when (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
127123
128124 SearchResponse searchResponse = mock (SearchResponse .class );
@@ -132,40 +128,30 @@ private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAl
132128 IndexResponse indexResponse = mock (IndexResponse .class );
133129 doAnswer (withResponse (indexResponse )).when (client ).execute (eq (IndexAction .INSTANCE ), any (), any ());
134130
135- TaskParams taskParams = new TaskParams ("task_id" , Version .CURRENT , Collections .emptyList (), false );
136- DataFrameAnalyticsTask task =
137- new DataFrameAnalyticsTask (
138- 0 ,
139- "" ,
140- "" ,
141- null ,
142- null ,
143- client ,
144- mock (ClusterService .class ),
145- mock (DataFrameAnalyticsManager .class ),
146- mock (DataFrameAnalyticsAuditor .class ),
147- taskParams );
148- task .markAsCompleted ();
131+ Runnable runnable = mock (Runnable .class );
132+
133+ DataFrameAnalyticsTask .persistProgress (client , "task_id" , runnable );
149134
150135 ArgumentCaptor <IndexRequest > indexRequestCaptor = ArgumentCaptor .forClass (IndexRequest .class );
151136
152- InOrder inOrder = inOrder (client );
137+ InOrder inOrder = inOrder (client , runnable );
153138 inOrder .verify (client ).execute (eq (GetDataFrameAnalyticsStatsAction .INSTANCE ), any (), any ());
154139 inOrder .verify (client ).execute (eq (SearchAction .INSTANCE ), any (), any ());
155140 inOrder .verify (client ).execute (eq (IndexAction .INSTANCE ), indexRequestCaptor .capture (), any ());
141+ inOrder .verify (runnable ).run ();
156142 inOrder .verifyNoMoreInteractions ();
157143
158144 IndexRequest indexRequest = indexRequestCaptor .getValue ();
159145 assertThat (indexRequest .index (), equalTo (expectedIndexOrAlias ));
160146 assertThat (indexRequest .id (), equalTo ("data_frame_analytics-task_id-progress" ));
161147 }
162148
163- public void testMarkAsCompleted_ProgressDocumentCreated () {
164- testMarkAsCompleted (SearchHits .empty (), ".ml-state-write" );
149+ public void testPersistProgress_ProgressDocumentCreated () {
150+ testPersistProgress (SearchHits .empty (), ".ml-state-write" );
165151 }
166152
167- public void testMarkAsCompleted_ProgressDocumentUpdated () {
168- testMarkAsCompleted (
153+ public void testPersistProgress_ProgressDocumentUpdated () {
154+ testPersistProgress (
169155 new SearchHits (new SearchHit []{ SearchHit .createFromMap (Map .of ("_index" , ".ml-state-dummy" )) }, null , 0.0f ),
170156 ".ml-state-dummy" );
171157 }
0 commit comments