Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase concurrent request of opening point-in-time #96782

Merged
merged 10 commits into from
Jun 20, 2023
5 changes: 5 additions & 0 deletions docs/changelog/96782.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 96782
summary: Increase concurrent request of opening point-in-time
area: Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.admin.indices.stats.CommonStats;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchContextMissingException;
import org.elasticsearch.search.SearchHit;
Expand All @@ -33,10 +36,14 @@
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.transport.TransportService;

import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

Expand All @@ -54,6 +61,11 @@

public class PointInTimeIT extends ESIntegTestCase {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return CollectionUtils.appendToCopy(super.nodePlugins(), MockTransportService.TestPlugin.class);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder()
Expand Down Expand Up @@ -430,6 +442,52 @@ public void testCloseInvalidPointInTime() {
assertThat(tasks, empty());
}

public void testOpenPITConcurrentShardRequests() throws Exception {
DiscoveryNode dataNode = randomFrom(clusterService().state().nodes().getDataNodes().values());
int numShards = randomIntBetween(5, 10);
int maxConcurrentRequests = randomIntBetween(2, 5);
assertAcked(
client().admin()
.indices()
.prepareCreate("test")
.setSettings(
Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards)
.put("index.routing.allocation.require._id", dataNode.getId())
.build()
)
);
var transportService = (MockTransportService) internalCluster().getInstance(TransportService.class, dataNode.getName());
try {
CountDownLatch sentLatch = new CountDownLatch(maxConcurrentRequests);
CountDownLatch readyLatch = new CountDownLatch(1);
transportService.addRequestHandlingBehavior(
TransportOpenPointInTimeAction.OPEN_SHARD_READER_CONTEXT_NAME,
(handler, request, channel, task) -> {
sentLatch.countDown();
Thread thread = new Thread(() -> {
try {
assertTrue(readyLatch.await(1, TimeUnit.MINUTES));
handler.messageReceived(request, channel, task);
} catch (Exception e) {
throw new AssertionError(e);
}
});
thread.start();
}
);
OpenPointInTimeRequest request = new OpenPointInTimeRequest("test").keepAlive(TimeValue.timeValueMinutes(1));
request.maxConcurrentShardRequests(maxConcurrentRequests);
PlainActionFuture<OpenPointInTimeResponse> future = new PlainActionFuture<>();
client().execute(OpenPointInTimeAction.INSTANCE, request, future);
assertTrue(sentLatch.await(1, TimeUnit.MINUTES));
readyLatch.countDown();
closePointInTime(future.actionGet().getPointInTimeId());
} finally {
transportService.clearAllRules();
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void assertPagination(PointInTimeBuilder pit, int expectedNumDocs, int size, SortBuilder<?>... sorts) throws Exception {
Set<String> seen = new HashSet<>();
Expand Down
3 changes: 2 additions & 1 deletion server/src/main/java/org/elasticsearch/TransportVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ private static TransportVersion registerTransportVersion(int id, String uniqueId
public static final TransportVersion V_8_500_014 = registerTransportVersion(8_500_014, "D115A2E1-1739-4A02-AB7B-64F6EA157EFB");
public static final TransportVersion V_8_500_015 = registerTransportVersion(8_500_015, "651216c9-d54f-4189-9fe1-48d82d276863");
public static final TransportVersion V_8_500_016 = registerTransportVersion(8_500_016, "492C94FB-AAEA-4C9E-8375-BDB67A398584");
public static final TransportVersion V_8_500_017 = registerTransportVersion(8_500_017, "0EDCB5BA-049C-443C-8AB1-5FA58FB996FB");

private static class CurrentHolder {
private static final TransportVersion CURRENT = findCurrent(V_8_500_016);
private static final TransportVersion CURRENT = findCurrent(V_8_500_017);

// finds the pluggable current version, or uses the given fallback
private static TransportVersion findCurrent(TransportVersion fallback) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.action.search;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.IndicesRequest;
Expand All @@ -20,16 +21,18 @@
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.action.ValidateActions.addValidationError;

public final class OpenPointInTimeRequest extends ActionRequest implements IndicesRequest.Replaceable {

private String[] indices;
private IndicesOptions indicesOptions = DEFAULT_INDICES_OPTIONS;
private TimeValue keepAlive;

private int maxConcurrentShardRequests = 5;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shall we link to the default constant for the same parameter in search, given the two have the same default value for now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I pushed a252899.

@Nullable
private String routing;
@Nullable
Expand All @@ -48,6 +51,9 @@ public OpenPointInTimeRequest(StreamInput in) throws IOException {
this.keepAlive = in.readTimeValue();
this.routing = in.readOptionalString();
this.preference = in.readOptionalString();
if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_500_017)) {
this.maxConcurrentShardRequests = in.readVInt();
}
}

@Override
Expand All @@ -58,6 +64,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeTimeValue(keepAlive);
out.writeOptionalString(routing);
out.writeOptionalString(preference);
if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_500_017)) {
out.writeVInt(maxConcurrentShardRequests);
}
}

@Override
Expand Down Expand Up @@ -123,6 +132,27 @@ public OpenPointInTimeRequest preference(String preference) {
return this;
}

/**
* Similar to {@link SearchRequest#getMaxConcurrentShardRequests()}, this returns the number of shard requests that should be
* executed concurrently on a single node . This value should be used as a protection mechanism to reduce the number of shard
* requests fired per open point-in-time request. The default is {@code 5}
*/
public int maxConcurrentShardRequests() {
return maxConcurrentShardRequests;
}

/**
* Similar to {@link SearchRequest#setMaxConcurrentShardRequests(int)}, this sets the number of shard requests that should be
* executed concurrently on a single node. This value should be used as a protection mechanism to reduce the number of shard
* requests fired per open point-in-time request.
*/
public void maxConcurrentShardRequests(int maxConcurrentShardRequests) {
if (maxConcurrentShardRequests < 1) {
throw new IllegalArgumentException("maxConcurrentShardRequests must be >= 1");
}
this.maxConcurrentShardRequests = maxConcurrentShardRequests;
}

@Override
public boolean allowsRemoteIndices() {
return true;
Expand All @@ -138,8 +168,46 @@ public String getDescription() {
return "open search context: indices [" + String.join(",", indices) + "] keep_alive [" + keepAlive + "]";
}

@Override
public String toString() {
return "OpenPointInTimeRequest{"
+ "indices="
+ Arrays.toString(indices)
+ ", keepAlive="
+ keepAlive
+ ", maxConcurrentShardRequests="
+ maxConcurrentShardRequests
+ ", routing='"
+ routing
+ '\''
+ ", preference='"
+ preference
+ '\''
+ '}';
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new SearchTask(id, type, action, this::getDescription, parentTaskId, headers);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OpenPointInTimeRequest that = (OpenPointInTimeRequest) o;
return maxConcurrentShardRequests == that.maxConcurrentShardRequests
&& Arrays.equals(indices, that.indices)
&& indicesOptions.equals(that.indicesOptions)
&& keepAlive.equals(that.keepAlive)
&& Objects.equals(routing, that.routing)
&& Objects.equals(preference, that.preference);
}

@Override
public int hashCode() {
int result = Objects.hash(indicesOptions, keepAlive, maxConcurrentShardRequests, routing, preference);
result = 31 * result + Arrays.hashCode(indices);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
openRequest.routing(request.param("routing"));
openRequest.preference(request.param("preference"));
openRequest.keepAlive(TimeValue.parseTimeValue(request.param("keep_alive"), null, "keep_alive"));
if (request.hasParam("max_concurrent_shard_requests")) {
final int maxConcurrentShardRequests = request.paramAsInt(
"max_concurrent_shard_requests",
openRequest.maxConcurrentShardRequests()
);
openRequest.maxConcurrentShardRequests(maxConcurrentShardRequests);
}
return channel -> client.execute(OpenPointInTimeAction.INSTANCE, openRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ protected void doExecute(Task task, OpenPointInTimeRequest request, ActionListen
.preference(request.preference())
.routing(request.routing())
.allowPartialSearchResults(false);
searchRequest.setMaxConcurrentShardRequests(request.maxConcurrentShardRequests());
searchRequest.setCcsMinimizeRoundtrips(false);
transportSearchAction.executeRequest((SearchTask) task, searchRequest, listener.map(r -> {
assert r.pointInTimeId() != null : r;
Expand Down Expand Up @@ -117,6 +118,8 @@ public SearchPhase newSearchPhase(
ThreadPool threadPool,
SearchResponse.Clusters clusters
) {
assert searchRequest.getMaxConcurrentShardRequests() == pitRequest.maxConcurrentShardRequests()
: searchRequest.getMaxConcurrentShardRequests() + " != " + pitRequest.maxConcurrentShardRequests();
return new AbstractSearchAsyncAction<>(
actionName,
logger,
Expand All @@ -132,7 +135,7 @@ public SearchPhase newSearchPhase(
clusterState,
task,
new ArraySearchPhaseResults<>(shardIterators.size()),
1,
searchRequest.getMaxConcurrentShardRequests(),
clusters
) {
@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.search;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.TransportVersionUtils;

import java.io.IOException;
import java.util.List;

import static org.hamcrest.Matchers.equalTo;

public class OpenPointInTimeRequestTests extends AbstractWireSerializingTestCase<OpenPointInTimeRequest> {
@Override
protected Writeable.Reader<OpenPointInTimeRequest> instanceReader() {
return OpenPointInTimeRequest::new;
}

@Override
protected OpenPointInTimeRequest createTestInstance() {
OpenPointInTimeRequest request = new OpenPointInTimeRequest("index-1", "index-2");
request.keepAlive(TimeValue.timeValueSeconds(randomIntBetween(1, 1000)));
if (randomBoolean()) {
request.maxConcurrentShardRequests(randomIntBetween(1, 10));
} else {
assertThat(request.maxConcurrentShardRequests(), equalTo(5));
}
if (randomBoolean()) {
request.preference(randomAlphaOfLength(10));
}
if (randomBoolean()) {
request.routing(randomAlphaOfLength(10));
}
return request;
}

@Override
protected OpenPointInTimeRequest mutateInstance(OpenPointInTimeRequest in) throws IOException {
return switch (between(0, 4)) {
case 0 -> {
OpenPointInTimeRequest request = new OpenPointInTimeRequest("new-index");
request.maxConcurrentShardRequests(in.maxConcurrentShardRequests());
request.keepAlive(in.keepAlive());
request.preference(in.preference());
request.routing(in.routing());
yield request;
}
case 1 -> {
OpenPointInTimeRequest request = new OpenPointInTimeRequest(in.indices());
request.maxConcurrentShardRequests(in.maxConcurrentShardRequests() + between(1, 10));
request.keepAlive(in.keepAlive());
request.preference(in.preference());
request.routing(in.routing());
yield request;
}
case 2 -> {
OpenPointInTimeRequest request = new OpenPointInTimeRequest(in.indices());
request.maxConcurrentShardRequests(in.maxConcurrentShardRequests());
request.keepAlive(TimeValue.timeValueSeconds(between(2000, 5000)));
request.preference(in.preference());
request.routing(in.routing());
yield request;
}
case 3 -> {
OpenPointInTimeRequest request = new OpenPointInTimeRequest(in.indices());
request.maxConcurrentShardRequests(in.maxConcurrentShardRequests());
request.keepAlive(in.keepAlive());
request.preference(randomAlphaOfLength(5));
request.routing(in.routing());
yield request;
}
case 4 -> {
OpenPointInTimeRequest request = new OpenPointInTimeRequest(in.indices());
request.maxConcurrentShardRequests(in.maxConcurrentShardRequests());
request.keepAlive(in.keepAlive());
request.preference(in.preference());
request.routing(randomAlphaOfLength(5));
yield request;
}
default -> throw new AssertionError("Unknown option");
};
}

public void testUseDefaultConcurrentForOldVersion() throws Exception {
TransportVersion previousVersion = TransportVersionUtils.getPreviousVersion(TransportVersion.V_8_500_017);
try (BytesStreamOutput output = new BytesStreamOutput()) {
TransportVersion version = TransportVersionUtils.randomVersionBetween(random(), TransportVersion.V_8_0_0, previousVersion);
output.setTransportVersion(version);
OpenPointInTimeRequest original = createTestInstance();
original.writeTo(output);
try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), new NamedWriteableRegistry(List.of()))) {
in.setTransportVersion(version);
OpenPointInTimeRequest copy = new OpenPointInTimeRequest(in);
assertThat(copy.maxConcurrentShardRequests(), equalTo(5));
}
}
}
}
Loading