3535import java .util .Map ;
3636import java .util .concurrent .atomic .AtomicBoolean ;
3737import java .util .concurrent .atomic .AtomicReference ;
38+ import java .util .function .BooleanSupplier ;
3839import java .util .function .Consumer ;
3940import java .util .function .Supplier ;
4041
4142/**
4243 * Task that tracks the progress of a currently running {@link SearchRequest}.
4344 */
4445final class AsyncSearchTask extends SearchTask {
46+ private final BooleanSupplier checkSubmitCancellation ;
4547 private final AsyncSearchId searchId ;
4648 private final Client client ;
4749 private final ThreadPool threadPool ;
@@ -68,6 +70,7 @@ final class AsyncSearchTask extends SearchTask {
6870 * @param type The type of the task.
6971 * @param action The action name.
7072 * @param parentTaskId The parent task id.
73+ * @param checkSubmitCancellation A boolean supplier that checks if the submit task has been cancelled.
7174 * @param originHeaders All the request context headers.
7275 * @param taskHeaders The filtered request headers for the task.
7376 * @param searchId The {@link AsyncSearchId} of the task.
@@ -78,6 +81,7 @@ final class AsyncSearchTask extends SearchTask {
7881 String type ,
7982 String action ,
8083 TaskId parentTaskId ,
84+ BooleanSupplier checkSubmitCancellation ,
8185 TimeValue keepAlive ,
8286 Map <String , String > originHeaders ,
8387 Map <String , String > taskHeaders ,
@@ -86,6 +90,7 @@ final class AsyncSearchTask extends SearchTask {
8690 ThreadPool threadPool ,
8791 Supplier <InternalAggregation .ReduceContext > aggReduceContextSupplier ) {
8892 super (id , type , action , "async_search" , parentTaskId , taskHeaders );
93+ this .checkSubmitCancellation = checkSubmitCancellation ;
8994 this .expirationTimeMillis = getStartTime () + keepAlive .getMillis ();
9095 this .originHeaders = originHeaders ;
9196 this .searchId = searchId ;
@@ -212,13 +217,13 @@ private void internalAddCompletionListener(ActionListener<AsyncSearchResponse> l
212217
213218 final Cancellable cancellable ;
214219 try {
215- cancellable = threadPool .schedule (() -> {
220+ cancellable = threadPool .schedule (threadPool . preserveContext ( () -> {
216221 if (hasRun .compareAndSet (false , true )) {
217222 // timeout occurred before completion
218223 removeCompletionListener (id );
219224 listener .onResponse (getResponse ());
220225 }
221- }, waitForCompletion , "generic" );
226+ }) , waitForCompletion , "generic" );
222227 } catch (EsRejectedExecutionException exc ) {
223228 listener .onFailure (exc );
224229 return ;
@@ -291,41 +296,45 @@ private AsyncSearchResponse getResponse() {
291296 return searchResponse .get ().toAsyncSearchResponse (this , expirationTimeMillis );
292297 }
293298
294- // cancels the task if it expired
295- private void checkExpiration () {
299+ // checks if the search task should be cancelled
300+ private void checkCancellation () {
296301 long now = System .currentTimeMillis ();
297- if (expirationTimeMillis < now ) {
302+ if (expirationTimeMillis < now || checkSubmitCancellation .getAsBoolean ()) {
303+ // we cancel the search task if the initial submit task was cancelled,
304+ // this is needed because the task cancellation mechanism doesn't
305+ // handle the cancellation of grand-children.
298306 cancelTask (() -> {});
299307 }
300308 }
301309
302310 class Listener extends SearchProgressActionListener {
303311 @ Override
304312 protected void onQueryResult (int shardIndex ) {
305- checkExpiration ();
313+ checkCancellation ();
306314 }
307315
308316 @ Override
309317 protected void onFetchResult (int shardIndex ) {
310- checkExpiration ();
318+ checkCancellation ();
311319 }
312320
313321 @ Override
314322 protected void onQueryFailure (int shardIndex , SearchShardTarget shardTarget , Exception exc ) {
315323 // best effort to cancel expired tasks
316- checkExpiration ();
317- searchResponse .get ().addShardFailure (shardIndex , new ShardSearchFailure (exc , shardTarget ));
324+ checkCancellation ();
325+ searchResponse .get ().addShardFailure (shardIndex ,
326+ new ShardSearchFailure (exc , shardTarget .getNodeId () != null ? shardTarget : null ));
318327 }
319328
320329 @ Override
321330 protected void onFetchFailure (int shardIndex , Exception exc ) {
322- checkExpiration ();
331+ checkCancellation ();
323332 }
324333
325334 @ Override
326335 protected void onListShards (List <SearchShard > shards , List <SearchShard > skipped , Clusters clusters , boolean fetchPhase ) {
327336 // best effort to cancel expired tasks
328- checkExpiration ();
337+ checkCancellation ();
329338 searchResponse .compareAndSet (null ,
330339 new MutableSearchResponse (shards .size () + skipped .size (), skipped .size (), clusters , aggReduceContextSupplier ));
331340 executeInitListeners ();
@@ -334,7 +343,7 @@ protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped,
334343 @ Override
335344 public void onPartialReduce (List <SearchShard > shards , TotalHits totalHits , InternalAggregations aggs , int reducePhase ) {
336345 // best effort to cancel expired tasks
337- checkExpiration ();
346+ checkCancellation ();
338347 searchResponse .get ().updatePartialResponse (shards .size (),
339348 new InternalSearchResponse (new SearchHits (SearchHits .EMPTY , totalHits , Float .NaN ), aggs ,
340349 null , null , false , null , reducePhase ), aggs == null );
@@ -343,7 +352,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
343352 @ Override
344353 public void onFinalReduce (List <SearchShard > shards , TotalHits totalHits , InternalAggregations aggs , int reducePhase ) {
345354 // best effort to cancel expired tasks
346- checkExpiration ();
355+ checkCancellation ();
347356 searchResponse .get ().updatePartialResponse (shards .size (),
348357 new InternalSearchResponse (new SearchHits (SearchHits .EMPTY , totalHits , Float .NaN ), aggs ,
349358 null , null , false , null , reducePhase ), true );
0 commit comments