Skip to content

Commit 68f4297

Browse files
authored
Improve async search's tasks cancellation (#53799)
This commit adds an explicit cancellation of the search task when the initial async search submit task is cancelled (for example, because the user closed the connection). Previously this relied on the cancellation of the parent task, but the cancellation of grand-children tasks is not handled yet, so the search task has to be cancelled manually to ensure that the shard actions are cancelled too. This change can be considered a workaround until #50990 is fixed.
1 parent aed8ce7 commit 68f4297

File tree

10 files changed

+377
-186
lines changed

10 files changed

+377
-186
lines changed

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@
3535
import java.util.Map;
3636
import java.util.concurrent.atomic.AtomicBoolean;
3737
import java.util.concurrent.atomic.AtomicReference;
38+
import java.util.function.BooleanSupplier;
3839
import java.util.function.Consumer;
3940
import java.util.function.Supplier;
4041

4142
/**
4243
* Task that tracks the progress of a currently running {@link SearchRequest}.
4344
*/
4445
final class AsyncSearchTask extends SearchTask {
46+
private final BooleanSupplier checkSubmitCancellation;
4547
private final AsyncSearchId searchId;
4648
private final Client client;
4749
private final ThreadPool threadPool;
@@ -68,6 +70,7 @@ final class AsyncSearchTask extends SearchTask {
6870
* @param type The type of the task.
6971
* @param action The action name.
7072
* @param parentTaskId The parent task id.
73+
* @param checkSubmitCancellation A boolean supplier that checks if the submit task has been cancelled.
7174
* @param originHeaders All the request context headers.
7275
* @param taskHeaders The filtered request headers for the task.
7376
* @param searchId The {@link AsyncSearchId} of the task.
@@ -78,6 +81,7 @@ final class AsyncSearchTask extends SearchTask {
7881
String type,
7982
String action,
8083
TaskId parentTaskId,
84+
BooleanSupplier checkSubmitCancellation,
8185
TimeValue keepAlive,
8286
Map<String, String> originHeaders,
8387
Map<String, String> taskHeaders,
@@ -86,6 +90,7 @@ final class AsyncSearchTask extends SearchTask {
8690
ThreadPool threadPool,
8791
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
8892
super(id, type, action, "async_search", parentTaskId, taskHeaders);
93+
this.checkSubmitCancellation = checkSubmitCancellation;
8994
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
9095
this.originHeaders = originHeaders;
9196
this.searchId = searchId;
@@ -212,13 +217,13 @@ private void internalAddCompletionListener(ActionListener<AsyncSearchResponse> l
212217

213218
final Cancellable cancellable;
214219
try {
215-
cancellable = threadPool.schedule(() -> {
220+
cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
216221
if (hasRun.compareAndSet(false, true)) {
217222
// timeout occurred before completion
218223
removeCompletionListener(id);
219224
listener.onResponse(getResponse());
220225
}
221-
}, waitForCompletion, "generic");
226+
}), waitForCompletion, "generic");
222227
} catch (EsRejectedExecutionException exc) {
223228
listener.onFailure(exc);
224229
return;
@@ -291,41 +296,45 @@ private AsyncSearchResponse getResponse() {
291296
return searchResponse.get().toAsyncSearchResponse(this, expirationTimeMillis);
292297
}
293298

294-
// cancels the task if it expired
295-
private void checkExpiration() {
299+
// cancels the search task if it has expired or if the initial submit task was cancelled
300+
private void checkCancellation() {
296301
long now = System.currentTimeMillis();
297-
if (expirationTimeMillis < now) {
302+
if (expirationTimeMillis < now || checkSubmitCancellation.getAsBoolean()) {
303+
// besides expiration, we also cancel the search task explicitly when the initial
304+
// submit task was cancelled; this is needed because the task cancellation
305+
// mechanism doesn't handle the cancellation of grand-children.
298306
cancelTask(() -> {});
299307
}
300308
}
301309

302310
class Listener extends SearchProgressActionListener {
303311
@Override
304312
protected void onQueryResult(int shardIndex) {
305-
checkExpiration();
313+
checkCancellation();
306314
}
307315

308316
@Override
309317
protected void onFetchResult(int shardIndex) {
310-
checkExpiration();
318+
checkCancellation();
311319
}
312320

313321
@Override
314322
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
315323
// best effort to cancel tasks that expired or whose submit task was cancelled
316-
checkExpiration();
317-
searchResponse.get().addShardFailure(shardIndex, new ShardSearchFailure(exc, shardTarget));
324+
checkCancellation();
// record the shard failure; the shard target is replaced by null when it has no node id
325+
searchResponse.get().addShardFailure(shardIndex,
326+
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null));
318327
}
319328

320329
@Override
321330
protected void onFetchFailure(int shardIndex, Exception exc) {
322-
checkExpiration();
331+
checkCancellation();
323332
}
324333

325334
@Override
326335
protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped, Clusters clusters, boolean fetchPhase) {
327336
// best effort to cancel expired tasks
328-
checkExpiration();
337+
checkCancellation();
329338
searchResponse.compareAndSet(null,
330339
new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, aggReduceContextSupplier));
331340
executeInitListeners();
@@ -334,7 +343,7 @@ protected void onListShards(List<SearchShard> shards, List<SearchShard> skipped,
334343
@Override
335344
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
336345
// best effort to cancel expired tasks
337-
checkExpiration();
346+
checkCancellation();
338347
searchResponse.get().updatePartialResponse(shards.size(),
339348
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
340349
null, null, false, null, reducePhase), aggs == null);
@@ -343,7 +352,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
343352
@Override
344353
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
345354
// best effort to cancel expired tasks
346-
checkExpiration();
355+
checkCancellation();
347356
searchResponse.get().updatePartialResponse(shards.size(),
348357
new InternalSearchResponse(new SearchHits(SearchHits.EMPTY, totalHits, Float.NaN), aggs,
349358
null, null, false, null, reducePhase), true);

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public TransportSubmitAsyncSearchAction(ClusterService clusterService,
6666
@Override
6767
protected void doExecute(Task task, SubmitAsyncSearchRequest request, ActionListener<AsyncSearchResponse> submitListener) {
6868
CancellableTask submitTask = (CancellableTask) task;
69-
final SearchRequest searchRequest = createSearchRequest(request, submitTask.getId(), request.getKeepAlive());
69+
final SearchRequest searchRequest = createSearchRequest(request, submitTask, request.getKeepAlive());
7070
AsyncSearchTask searchTask = (AsyncSearchTask) taskManager.register("transport", SearchAction.INSTANCE.name(), searchRequest);
7171
searchAction.execute(searchTask, searchRequest, searchTask.getSearchProgressActionListener());
7272
searchTask.addCompletionListener(
@@ -81,7 +81,7 @@ public void onResponse(AsyncSearchResponse searchResponse) {
8181
// the user cancelled the submit so we don't store anything
8282
// and propagate the failure
8383
Exception cause = new TaskCancelledException(submitTask.getReasonCancelled());
84-
onFatalFailure(searchTask, cause, false, submitListener);
84+
onFatalFailure(searchTask, cause, searchResponse.isRunning(), submitListener);
8585
} else {
8686
final String docId = searchTask.getSearchId().getDocId();
8787
// creates the fallback response if the node crashes/restarts in the middle of the request
@@ -129,7 +129,7 @@ public void onFailure(Exception exc) {
129129
}, request.getWaitForCompletion());
130130
}
131131

132-
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, long parentTaskId, TimeValue keepAlive) {
132+
private SearchRequest createSearchRequest(SubmitAsyncSearchRequest request, CancellableTask submitTask, TimeValue keepAlive) {
133133
String docID = UUIDs.randomBase64UUID();
134134
Map<String, String> originHeaders = nodeClient.threadPool().getThreadContext().getHeaders();
135135
SearchRequest searchRequest = new SearchRequest(request.getSearchRequest()) {
@@ -138,16 +138,17 @@ public AsyncSearchTask createTask(long id, String type, String action, TaskId pa
138138
AsyncSearchId searchId = new AsyncSearchId(docID, new TaskId(nodeClient.getLocalNodeId(), id));
139139
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier =
140140
() -> requestToAggReduceContextBuilder.apply(request.getSearchRequest());
141-
return new AsyncSearchTask(id, type, action, parentTaskId, keepAlive, originHeaders, taskHeaders, searchId,
142-
store.getClient(), nodeClient.threadPool(), aggReduceContextSupplier);
141+
return new AsyncSearchTask(id, type, action, parentTaskId,
142+
() -> submitTask.isCancelled(), keepAlive, originHeaders, taskHeaders, searchId, store.getClient(),
143+
nodeClient.threadPool(), aggReduceContextSupplier);
143144
}
144145
};
145-
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), parentTaskId));
146+
searchRequest.setParentTask(new TaskId(nodeClient.getLocalNodeId(), submitTask.getId()));
146147
return searchRequest;
147148
}
148149

149150
private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shouldCancel, ActionListener<AsyncSearchResponse> listener) {
150-
if (shouldCancel) {
151+
if (shouldCancel && task.isCancelled() == false) {
151152
task.cancelTask(() -> {
152153
try {
153154
task.addCompletionListener(finalResponse -> taskManager.unregister(task));

x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionTests.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,4 +253,28 @@ public void testNoIndex() throws Exception {
253253
ElasticsearchException exc = response.getFailure();
254254
assertThat(exc.getMessage(), containsString("no such index"));
255255
}
256+
257+
// Submits an async search with a 1ms wait-for-completion, checks that both the submit
// response and a later get report a running search with no completed or failed shards,
// then deletes the search and waits for its task to be removed.
// NOTE(review): presumably CancellingAggregationBuilder keeps the shard requests from
// completing until they are cancelled -- confirm against its implementation.
public void testCancellation() throws Exception {
258+
SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
259+
request.getSearchRequest().source(
260+
new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"))
261+
);
262+
request.setWaitForCompletion(TimeValue.timeValueMillis(1));
263+
AsyncSearchResponse response = submitAsyncSearch(request);
// the submit returns a partial response: still running, no shard has reported yet
264+
assertNotNull(response.getSearchResponse());
265+
assertTrue(response.isRunning());
266+
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
267+
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
268+
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
269+
270+
response = getAsyncSearch(response.getId());
// a subsequent get must report the same state
271+
assertNotNull(response.getSearchResponse());
272+
assertTrue(response.isRunning());
273+
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
274+
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
275+
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
276+
277+
// delete the async search, then wait until its task is gone
deleteAsyncSearch(response.getId());
278+
ensureTaskRemoval(response.getId());
279+
}
256280
}

0 commit comments

Comments
 (0)