Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/144010.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
area: Search
issues: []
pr: 144010
summary: Expose keep_alive in async task status
type: enhancement
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9318000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.4.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
esql_async_source_bytes_buffered,9317000
async_task_keep_alive_status,9318000
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import org.elasticsearch.search.aggregations.metrics.Min;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.query.ThrowingQueryBuilder;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase.SuiteScopeTestCase;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
import org.elasticsearch.xpack.core.search.action.AsyncStatusResponse;
import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchRequest;
Expand Down Expand Up @@ -417,6 +419,7 @@ public void testUpdateRunningKeepAlive() throws Exception {
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
assertThat(response.getExpirationTime(), greaterThan(now));
expirationTime = response.getExpirationTime();
assertThat(getRunningAsyncSearchTask(responseId).toString(), containsString("\"keep_alive\" : \"5d\""));
} finally {
response.decRef();
}
Expand Down Expand Up @@ -445,6 +448,7 @@ public void testUpdateRunningKeepAlive() throws Exception {
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));
assertThat(getRunningAsyncSearchTask(response.getId()).toString(), containsString("\"keep_alive\" : \"10d\""));

AsyncStatusResponse statusResponse = getAsyncStatus(response.getId(), TimeValue.timeValueDays(10));
assertTrue(statusResponse.isRunning());
Expand Down Expand Up @@ -491,6 +495,22 @@ public void testUpdateRunningKeepAlive() throws Exception {
}
}

private TaskInfo getRunningAsyncSearchTask(String asyncSearchId) throws Exception {
var targetTaskId = AsyncExecutionId.decode(asyncSearchId).getTaskId();
TaskInfo found = client().admin()
.cluster()
.prepareListTasks()
.setDetailed(true)
.get()
.getTasks()
.stream()
.filter(taskInfo -> taskInfo.taskId().equals(targetTaskId))
.findAny()
.orElse(null);
assertNotNull(found);
return found;
}

public void testUpdateStoreKeepAlive() throws Exception {
SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
long now = System.currentTimeMillis();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.Requests;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
Expand All @@ -32,16 +34,19 @@
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.tasks.RawTaskStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
import org.elasticsearch.xpack.core.async.AsyncTask;
import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
import org.elasticsearch.xpack.core.search.action.AsyncStatusResponse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -75,6 +80,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
private final Map<Long, Consumer<AsyncSearchResponse>> completionListeners = new HashMap<>();

private volatile long expirationTimeMillis;
private volatile TimeValue keepAlive;
private final AtomicBoolean isCancelling = new AtomicBoolean(false);

private final MutableSearchResponse searchResponse;
Expand Down Expand Up @@ -112,6 +118,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
) {
super(id, type, action, () -> "async_search{" + descriptionSupplier.get() + "}", parentTaskId, taskHeaders);
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.keepAlive = keepAlive;
this.originHeaders = originHeaders;
this.searchId = searchId;
this.client = client;
Expand Down Expand Up @@ -147,8 +154,27 @@ Listener getSearchProgressActionListener() {
* Update the expiration time of the (partial) response.
*/
@Override
public void setExpirationTime(long expirationTime) {
public void setExpirationTime(long expirationTime, TimeValue keepAlive) {
this.expirationTimeMillis = expirationTime;
this.keepAlive = keepAlive;
}

@Override
public TimeValue getKeepAlive() {
return keepAlive;
}

@Override
public Status getStatus() {
try (XContentBuilder builder = XContentBuilder.builder(Requests.INDEX_CONTENT_TYPE.xContent())) {
builder.startObject();
builder.field("request_id", searchId.getEncoded());
builder.field("keep_alive", keepAlive.getStringRep());
builder.endObject();
return new RawTaskStatus(BytesReference.bytes(builder));
} catch (IOException e) {
throw new IllegalStateException("failed to build async search task status", e);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ protected void doExecute(Task task, GetAsyncStatusRequest request, ActionListene
store.updateExpirationTime(searchId.getDocId(), expirationTime, ActionListener.wrap(p -> {
AsyncSearchTask asyncSearchTask = getTask(taskManager, searchId, AsyncSearchTask.class);
if (asyncSearchTask != null) {
asyncSearchTask.setExpirationTime(expirationTime);
asyncSearchTask.setExpirationTime(expirationTime, request.getKeepAlive());
}
store.retrieveStatus(
request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
Expand All @@ -18,6 +19,7 @@
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.shard.ShardId;
Expand All @@ -30,13 +32,18 @@
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.RawTaskStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.TransportVersionUtils;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.xpack.core.async.AsyncExecutionId;
import org.elasticsearch.xpack.core.async.AsyncTask;
import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
import org.junit.After;
import org.junit.Before;
Expand All @@ -45,12 +52,17 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -126,6 +138,29 @@ public void testTaskDescription() {
}
}

public void testTaskStatusIncludesKeepAlive() {
try (AsyncSearchTask task = createAsyncSearchTask()) {
TaskInfo taskInfo = task.taskInfo("node1", true);
assertThat(taskInfo.status(), notNullValue());
assertThat(taskInfo, hasToString(containsString("\"request_id\" : \"" + task.getExecutionId().getEncoded() + "\"")));
assertThat(taskInfo, hasToString(containsString("\"keep_alive\" : \"1h\"")));
}
}

public void testTaskStatusSerializationToPreviousTransportVersionUsesRawTaskStatus() throws IOException {
try (AsyncSearchTask task = createAsyncSearchTask()) {
NamedWriteableRegistry oldRegistry = new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(Task.Status.class, RawTaskStatus.NAME, RawTaskStatus::new))
);
TaskInfo taskInfo = task.taskInfo("node1", true);
TransportVersion previousVersion = TransportVersionUtils.randomVersionNotSupporting(AsyncTask.ASYNC_TASK_KEEP_ALIVE_STATUS);
TaskInfo serialized = copyWriteable(taskInfo, oldRegistry, TaskInfo::from, previousVersion);
assertThat(serialized.status(), instanceOf(RawTaskStatus.class));
Map<String, Object> statusMap = ((RawTaskStatus) serialized.status()).toMap();
assertThat(statusMap, allOf(hasEntry("request_id", task.getExecutionId().getEncoded()), hasEntry("keep_alive", "1h")));
}
}

public void testWaitForInit() throws InterruptedException {
try (
AsyncSearchTask task = new AsyncSearchTask(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ private void getSearchResponseFromTask(
}

if (expirationTimeMillis != -1) {
task.setExpirationTime(expirationTimeMillis);
task.setExpirationTime(expirationTimeMillis, request.getKeepAlive());
}
boolean added = addCompletionListener.apply(
task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.core.async;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.TaskManager;

import java.util.Map;
Expand All @@ -15,6 +17,11 @@
* A task that supports asynchronous execution and provides information necessary for safe temporary storage of results
*/
public interface AsyncTask {
/**
* Transport version that added {@code keep_alive} to async task status payloads.
*/
TransportVersion ASYNC_TASK_KEEP_ALIVE_STATUS = TransportVersion.fromName("async_task_keep_alive_status");

/**
* Returns all of the request contexts headers
*/
Expand All @@ -33,7 +40,12 @@ public interface AsyncTask {
/**
* Update the expiration time of the (partial) response.
*/
void setExpirationTime(long expirationTimeMillis);
void setExpirationTime(long expirationTimeMillis, TimeValue keepAlive);

/**
* Returns the currently effective keep-alive for this task.
*/
TimeValue getKeepAlive();

/**
* Performs necessary checks, cancels the task and calls the runnable upon completion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public abstract class StoredAsyncTask<Response extends ActionResponse> extends C
private final AsyncExecutionId asyncExecutionId;
private final Map<String, String> originHeaders;
private volatile long expirationTimeMillis;
private volatile TimeValue keepAlive;
protected final List<ActionListener<Response>> completionListeners;
private boolean hasCompleted = false;

Expand All @@ -43,6 +44,7 @@ public StoredAsyncTask(
this.asyncExecutionId = asyncExecutionId;
this.originHeaders = originHeaders;
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.keepAlive = keepAlive;
this.completionListeners = new ArrayList<>();
}

Expand All @@ -60,14 +62,20 @@ public AsyncExecutionId getExecutionId() {
* Update the expiration time of the (partial) response.
*/
@Override
public void setExpirationTime(long expirationTime) {
public void setExpirationTime(long expirationTime, TimeValue keepAlive) {
this.expirationTimeMillis = expirationTime;
this.keepAlive = keepAlive;
}

public long getExpirationTimeMillis() {
return expirationTimeMillis;
}

@Override
public TimeValue getKeepAlive() {
return keepAlive;
}

public synchronized boolean addCompletionListener(Supplier<ActionListener<Response>> listenerSupplier) {
if (hasCompleted) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public static class TestTask extends CancellableTask implements AsyncTask {
private final AsyncExecutionId executionId;
private final Map<ActionListener<TestAsyncResponse>, TimeValue> listeners = new HashMap<>();
private long expirationTimeMillis;
private TimeValue keepAlive = TimeValue.ZERO;

public TestTask(
AsyncExecutionId executionId,
Expand Down Expand Up @@ -83,8 +84,14 @@ public AsyncExecutionId getExecutionId() {
}

@Override
public void setExpirationTime(long expirationTime) {
public void setExpirationTime(long expirationTime, TimeValue keepAlive) {
this.expirationTimeMillis = expirationTime;
this.keepAlive = keepAlive;
}

@Override
public TimeValue getKeepAlive() {
return keepAlive;
}

@Override
Expand Down Expand Up @@ -192,7 +199,7 @@ public void testRetrieveFromMemoryWithExpiration() throws Exception {
try {
boolean shouldExpire = randomBoolean();
long expirationTime = System.currentTimeMillis() + randomLongBetween(100000, 1000000) * (shouldExpire ? -1 : 1);
task.setExpirationTime(expirationTime);
task.setExpirationTime(expirationTime, TimeValue.timeValueMillis(Math.max(expirationTime - System.currentTimeMillis(), 0)));

if (updateInitialResultsInStore) {
// we need to store initial result
Expand Down Expand Up @@ -240,7 +247,7 @@ public void testAssertExpirationPropagation() throws Exception {
TestTask task = (TestTask) taskManager.register("test", "test", request);
try {
long startTime = System.currentTimeMillis();
task.setExpirationTime(startTime + TimeValue.timeValueMinutes(1).getMillis());
task.setExpirationTime(startTime + TimeValue.timeValueMinutes(1).getMillis(), TimeValue.timeValueMinutes(1));
boolean taskCompleted = randomBoolean();
if (taskCompleted) {
taskManager.unregister(task);
Expand Down Expand Up @@ -286,7 +293,7 @@ public void testRetrieveFromDisk() throws Exception {
TestTask task = (TestTask) taskManager.register("test", "test", request);
try {
long startTime = System.currentTimeMillis();
task.setExpirationTime(startTime + TimeValue.timeValueMinutes(1).getMillis());
task.setExpirationTime(startTime + TimeValue.timeValueMinutes(1).getMillis(), TimeValue.timeValueMinutes(1));

if (updateInitialResultsInStore) {
// we need to store initial result
Expand Down Expand Up @@ -355,7 +362,7 @@ public void testFailWithIncompatibleResults() throws Exception {
TestTask task = (TestTask) taskManager.register("test", "test", request);
try {
long startTime = System.currentTimeMillis();
task.setExpirationTime(startTime + TimeValue.timeValueMinutes(1).getMillis());
task.setExpirationTime(startTime + TimeValue.timeValueMinutes(1).getMillis(), TimeValue.timeValueMinutes(1));

// we need to store initial result
PlainActionFuture<DocWriteResponse> futureCreate = new PlainActionFuture<>();
Expand Down
Loading
Loading