diff --git a/docs/changelog/143016.yaml b/docs/changelog/143016.yaml new file mode 100644 index 0000000000000..f10d3ea8ed24e --- /dev/null +++ b/docs/changelog/143016.yaml @@ -0,0 +1,6 @@ +area: ES|QL +issues: + - 142662 +pr: 143016 +summary: Cancel async query on expiry +type: bug diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java index 60835596e0468..23f59e11fefe0 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java @@ -42,6 +42,8 @@ import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getValuesList; +import static org.elasticsearch.xpack.esql.action.EsqlQueryRequest.asyncEsqlQueryRequest; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -324,6 +326,52 @@ public void testUpdateKeepAlive() throws Exception { } } + public void testCancelOnExpiry() throws Exception { + TimeValue keepAlive = timeValueMillis(between(1000, 2000)); + var request = asyncEsqlQueryRequest(); + request.query("from test | stats sum(pause_me)"); + request.waitForCompletionTimeout(TimeValue.timeValueMillis(between(1, 10))); + request.keepOnCompletion(randomBoolean()); + request.keepAlive(keepAlive); + final String asyncId; + try { + try (EsqlQueryResponse initialResponse = client().execute(EsqlQueryAction.INSTANCE, request).actionGet(60, TimeUnit.SECONDS)) { + assertThat(initialResponse.isRunning(), is(true)); + 
assertTrue(initialResponse.asyncExecutionId().isPresent()); + asyncId = initialResponse.asyncExecutionId().get(); + } + // all the started drivers were canceled + assertBusy(() -> { + List tasks = client().admin() + .cluster() + .prepareListTasks() + .setActions(DriverTaskRunner.ACTION_NAME) + .setDetailed(true) + .get() + .getTasks(); + for (TaskInfo task : tasks) { + assertTrue(task.cancelled()); + } + }); + // the async task was canceled + assertBusy(() -> { + List queryTasks = getEsqlQueryTasks(); + assertThat(queryTasks, hasSize(1)); + assertTrue(queryTasks.get(0).cancelled()); + }); + } finally { + scriptPermits.release(numberOfDocs()); + } + TaskCancelledException error = expectThrows(TaskCancelledException.class, () -> { + var getRequest = new GetAsyncResultRequest(asyncId).setWaitForCompletionTimeout(timeValueSeconds(10)) + .setKeepAlive(timeValueSeconds(30)); + try (var resp = client().execute(EsqlAsyncGetResultAction.INSTANCE, getRequest).actionGet()) { + assertThat(resp.isRunning(), is(false)); + } + }); + assertThat(error.getMessage(), containsString("keep_alive expired")); + } + private static long getExpirationFromTask(String asyncId) { List tasks = new ArrayList<>(); for (TransportService ts : internalCluster().getInstances(TransportService.class)) { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java index 90ae9332c85b0..a177bb76480a5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java @@ -44,9 +44,7 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.tasks.CancellableTask; -import 
org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.async.AsyncExecutionId; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -204,17 +202,7 @@ private void runLookup(DataType keyType, PopulateIndices populateIndices) throws }, EsqlPlugin.STORED_FIELDS_SEQUENTIAL_PROPORTION.getDefault(Settings.EMPTY))), 0 ); - CancellableTask parentTask = new EsqlQueryTask( - 1, - "test", - "test", - "test", - null, - Map.of(), - Map.of(), - new AsyncExecutionId("test", TaskId.EMPTY_TASK_ID), - TEST_REQUEST_TIMEOUT - ); + CancellableTask parentTask = new CancellableTask(1, "test", "test", "test", null, Map.of()); final String finalNodeWithShard = nodeWithShard; LookupFromIndexOperator.Factory lookup = new LookupFromIndexOperator.Factory( "test", diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryTask.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryTask.java index 4d7565a5d7863..0376e04f95d62 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryTask.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryTask.java @@ -9,15 +9,18 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.xpack.core.async.AsyncExecutionId; import org.elasticsearch.xpack.core.async.StoredAsyncTask; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; -public class EsqlQueryTask extends StoredAsyncTask { +public abstract class EsqlQueryTask extends StoredAsyncTask { private EsqlExecutionInfo executionInfo; + private final AtomicReference scheduledCancellation = new AtomicReference<>(); public 
EsqlQueryTask( long id, @@ -47,4 +50,41 @@ public EsqlQueryResponse getCurrentResult() { // TODO it'd be nice to have the number of documents we've read from completed drivers here return new EsqlQueryResponse(List.of(), List.of(), 0, 0, null, false, getExecutionId().getEncoded(), true, true, executionInfo); } + + @Override + public void onResponse(EsqlQueryResponse response) { + removeScheduledCancellation(); + super.onResponse(response); + } + + @Override + public void onFailure(Exception e) { + removeScheduledCancellation(); + super.onFailure(e); + } + + @Override + public void setExpirationTime(long expirationTime) { + super.setExpirationTime(expirationTime); + rescheduleCancellationOnExpiry(); + } + + private void removeScheduledCancellation() { + var prev = scheduledCancellation.getAndSet(null); + if (prev != null) { + prev.cancel(); + } + } + + /** + * Schedules cancellation of this task at the given absolute expiration time in milliseconds and returns the cancellable handle + */ + protected abstract Scheduler.ScheduledCancellable scheduleCancellationOnExpiry(long expirationTimeMillis); + + public void rescheduleCancellationOnExpiry() { + var prev = scheduledCancellation.getAndSet(scheduleCancellationOnExpiry(getExpirationTimeMillis())); + if (prev != null) { + prev.cancel(); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 91c0e77714216..5fffaac9f1979 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -19,15 +19,20 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import 
org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.BlockFactoryProvider; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.search.SearchService; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.Scheduler; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterAware; import org.elasticsearch.transport.RemoteClusterService; @@ -70,6 +75,8 @@ public class TransportEsqlQueryAction extends HandledTransportAction { + private static final Logger logger = LogManager.getLogger(TransportEsqlQueryAction.class); + private final ThreadPool threadPool; private final PlanExecutor planExecutor; private final ComputeService computeService; @@ -197,6 +204,7 @@ private void doExecuteForked(Task task, EsqlQueryRequest request, ActionListener public void execute(EsqlQueryRequest request, EsqlQueryTask task, ActionListener listener) { // set EsqlExecutionInfo on async-search task so that it is accessible to GET _query/async while the query is still running task.setExecutionInfo(createEsqlExecutionInfo(request)); + task.rescheduleCancellationOnExpiry(); ActionListener.run(listener, l -> innerExecute(task, request, l)); } @@ -411,7 +419,24 @@ public EsqlQueryTask createTask( originHeaders, asyncExecutionId, request.keepAlive() - ); + ) { + @Override + protected Scheduler.ScheduledCancellable scheduleCancellationOnExpiry(long expirationTimeMillis) { + final long delay = Math.max(expirationTimeMillis - threadPool.absoluteTimeInMillis(), 0); + final CancellableTask task = this; + return threadPool.schedule(new AbstractRunnable() { + @Override + 
public void onFailure(Exception e) { + logger.warn("failed to cancel async-query on expiry", e); + } + + @Override + protected void doRun() { + taskManager.cancelTaskAndDescendants(task, "keep_alive expired", false, ActionListener.noop()); + } + }, TimeValue.timeValueMillis(delay), threadPool.generic()); + } + }; } @Override