Limit concurrent shards per node for ESQL (#104832)
Today, we allow ESQL to execute against an unlimited number of shards
concurrently on each node. This can lead to cases where we open and hold
too many shards at once, which is equivalent to opening too many file
descriptors or using too much memory for FieldInfos in
ValuesSourceReaderOperator.

This change limits the number of concurrent shards to 10 per node. This
number was chosen based on the _search API, whose default limit is 5
concurrent shard requests per node. Besides the primary reason stated
above, this change has other implications:

- We might execute fewer shards for queries with only a LIMIT, leading
  to scenarios where we execute only some high-priority shards and then
  stop.
- For now, we don't have a partial reduce at the node level; if we
  introduce one in the future, it might not be as efficient as
  executing all shards at the same time.
- There are pauses between batches, because batches are executed
  sequentially, one by one (a simplified sketch of this batching loop
  follows this list).

However, I believe the performance of queries executing against many
shards (after can_match) matters less than resiliency.
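A minimal, illustrative sketch of that batching loop (all names here
are hypothetical; the real implementation is the DataNodeRequestExecutor
later in this commit):

    import java.util.List;
    import java.util.function.BiConsumer;
    import java.util.function.BooleanSupplier;

    /**
     * Illustrative only: runs shards in sequential batches of at most
     * maxConcurrentShards, and stops early once the consumer already
     * has all the data it needs (e.g. a LIMIT is satisfied).
     */
    class BatchedShardRunner<S> {
        private final List<S> shards;
        private final int maxConcurrentShards;

        BatchedShardRunner(List<S> shards, int maxConcurrentShards) {
            this.shards = shards;
            this.maxConcurrentShards = maxConcurrentShards;
        }

        // executeBatch runs one batch asynchronously and invokes the
        // callback when the batch completes; doneEarly reports whether
        // the consumer needs no more data.
        void run(BiConsumer<List<S>, Runnable> executeBatch, BooleanSupplier doneEarly) {
            runBatch(0, executeBatch, doneEarly);
        }

        private void runBatch(int start, BiConsumer<List<S>, Runnable> executeBatch, BooleanSupplier doneEarly) {
            int end = Math.min(start + maxConcurrentShards, shards.size());
            executeBatch.accept(shards.subList(start, end), () -> {
                if (end < shards.size() && doneEarly.getAsBoolean() == false) {
                    runBatch(end, executeBatch, doneEarly); // next batch only after this one completes
                }
            });
        }
    }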

Closes #103666
Backport of #104832
dnhatn authored Jan 30, 2024
1 parent b3bca4b commit 6bf2c71
Showing 10 changed files with 212 additions and 36 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/104832.yaml
@@ -0,0 +1,6 @@
pr: 104832
summary: Limit concurrent shards per node for ESQL
area: ES|QL
type: bug
issues:
- 103666
@@ -1060,7 +1060,7 @@ protected SearchContext createContext(
return context;
}

public DefaultSearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
public SearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
final IndexService indexService = indicesService.indexServiceSafe(request.shardId().getIndex());
final IndexShard indexShard = indexService.getShard(request.shardId().getId());
final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier();
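Note: the return type here is widened from DefaultSearchContext to
SearchContext, which lets test infrastructure such as MockSearchService
(changed below) override createSearchContext and return an instrumented
context.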
@@ -1264,7 +1264,7 @@ public void testCreateSearchContext() throws IOException {
nowInMillis,
clusterAlias
);
try (DefaultSearchContext searchContext = service.createSearchContext(request, new TimeValue(System.currentTimeMillis()))) {
try (SearchContext searchContext = service.createSearchContext(request, new TimeValue(System.currentTimeMillis()))) {
SearchShardTarget searchShardTarget = searchContext.shardTarget();
SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext();
String expectedIndexName = clusterAlias == null ? index : clusterAlias + ":" + index;
@@ -11,6 +11,7 @@
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.indices.ExecutorSelector;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -41,6 +42,7 @@ public static class TestPlugin extends Plugin {}
private static final Map<ReaderContext, Throwable> ACTIVE_SEARCH_CONTEXTS = new ConcurrentHashMap<>();

private Consumer<ReaderContext> onPutContext = context -> {};
private Consumer<ReaderContext> onRemoveContext = context -> {};

private Consumer<SearchContext> onCreateSearchContext = context -> {};

@@ -110,6 +112,7 @@ protected void putReaderContext(ReaderContext context) {
protected ReaderContext removeReaderContext(long id) {
final ReaderContext removed = super.removeReaderContext(id);
if (removed != null) {
onRemoveContext.accept(removed);
removeActiveContext(removed);
}
return removed;
@@ -119,6 +122,10 @@ public void setOnPutContext(Consumer<ReaderContext> onPutContext) {
this.onPutContext = onPutContext;
}

public void setOnRemoveContext(Consumer<ReaderContext> onRemoveContext) {
this.onRemoveContext = onRemoveContext;
}

public void setOnCreateSearchContext(Consumer<SearchContext> onCreateSearchContext) {
this.onCreateSearchContext = onCreateSearchContext;
}
@@ -141,6 +148,14 @@ protected SearchContext createContext(
return searchContext;
}

@Override
public SearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
SearchContext searchContext = super.createSearchContext(request, timeout);
onPutContext.accept(searchContext.readerContext());
searchContext.addReleasable(() -> onRemoveContext.accept(searchContext.readerContext()));
return searchContext;
}

public void setOnCheckCancelled(Function<SearchShardTask, SearchShardTask> onCheckCancelled) {
this.onCheckCancelled = onCheckCancelled;
}
@@ -108,7 +108,10 @@ public void addCompletionListener(ActionListener<Void> listener) {
completionFuture.addListener(listener);
}

boolean isFinished() {
/**
* Returns true if an exchange is finished
*/
public boolean isFinished() {
return completionFuture.isDone();
}

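Making isFinished() public lets the data-node batch executor added
later in this commit check, between batches, whether the coordinator
has already fetched all the pages it needs and stop early instead of
opening the remaining shards.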
@@ -164,6 +164,9 @@ protected static QueryPragmas randomPragmas() {
};
settings.put("page_size", pageSize);
}
if (randomBoolean()) {
settings.put("max_concurrent_shards_per_node", randomIntBetween(1, 10));
}
}
return new QueryPragmas(settings.build());
}
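The new pragma can be set like any other ESQL query pragma. A minimal
sketch, assuming the ES|QL plugin classes are on the classpath (the
value 5 is arbitrary; the setting key and accessor are taken from the
diffs in this commit):

    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.xpack.esql.plugin.QueryPragmas;

    class PragmaSketch {
        public static void main(String[] args) {
            // Build pragmas the same way randomPragmas() above does.
            Settings settings = Settings.builder()
                .put("max_concurrent_shards_per_node", 5) // cap each data node at 5 concurrent shards
                .build();
            QueryPragmas pragmas = new QueryPragmas(settings);
            // The data-node executor sizes each batch from this value.
            System.out.println(pragmas.maxConcurrentShardsPerNode());
        }
    }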
@@ -13,19 +13,36 @@
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.hamcrest.Matchers;
import org.junit.Before;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
* Makes sure that we can run many concurrent requests with a large number of shards with any data_partitioning.
*/
@LuceneTestCase.SuppressFileSystems(value = "HandleLimitFS")
public class ManyShardsIT extends AbstractEsqlIntegTestCase {

public void testConcurrentQueries() throws Exception {
@Override
protected Collection<Class<? extends Plugin>> getMockPlugins() {
var plugins = new ArrayList<>(super.getMockPlugins());
plugins.add(MockSearchService.TestPlugin.class);
return plugins;
}

@Before
public void setupIndices() {
int numIndices = between(10, 20);
for (int i = 0; i < numIndices; i++) {
String index = "test-" + i;
@@ -49,6 +66,9 @@ }
}
bulk.get();
}
}

public void testConcurrentQueries() throws Exception {
int numQueries = between(10, 20);
Thread[] threads = new Thread[numQueries];
CountDownLatch latch = new CountDownLatch(1);
@@ -76,4 +96,57 @@ public void testConcurrentQueries() throws Exception {
thread.join();
}
}

static class SearchContextCounter {
private final int maxAllowed;
private final AtomicInteger current = new AtomicInteger();

SearchContextCounter(int maxAllowed) {
this.maxAllowed = maxAllowed;
}

void onNewContext() {
int total = current.incrementAndGet();
assertThat("opening more shards than the limit", total, Matchers.lessThanOrEqualTo(maxAllowed));
}

void onContextReleased() {
int total = current.decrementAndGet();
assertThat(total, Matchers.greaterThanOrEqualTo(0));
}
}

public void testLimitConcurrentShards() {
Iterable<SearchService> searchServices = internalCluster().getInstances(SearchService.class);
try {
var queries = List.of(
"from test-* | stats count(user) by tags",
"from test-* | stats count(user) by tags | LIMIT 0",
"from test-* | stats count(user) by tags | LIMIT 1",
"from test-* | stats count(user) by tags | LIMIT 1000",
"from test-* | LIMIT 0",
"from test-* | LIMIT 1",
"from test-* | LIMIT 1000",
"from test-* | SORT tags | LIMIT 0",
"from test-* | SORT tags | LIMIT 1",
"from test-* | SORT tags | LIMIT 1000"
);
for (String q : queries) {
QueryPragmas pragmas = randomPragmas();
for (SearchService searchService : searchServices) {
SearchContextCounter counter = new SearchContextCounter(pragmas.maxConcurrentShardsPerNode());
var mockSearchService = (MockSearchService) searchService;
mockSearchService.setOnPutContext(r -> counter.onNewContext());
mockSearchService.setOnRemoveContext(r -> counter.onContextReleased());
}
run(q, pragmas).close();
}
} finally {
for (SearchService searchService : searchServices) {
var mockSearchService = (MockSearchService) searchService;
mockSearchService.setOnPutContext(r -> {});
mockSearchService.setOnRemoveContext(r -> {});
}
}
}
}
@@ -9,6 +9,7 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.junit.annotations.TestLogging;
@@ -40,7 +41,11 @@ public void testCollectWarnings() {
client().admin()
.indices()
.prepareCreate("index-1")
.setSettings(Settings.builder().put("index.routing.allocation.require._name", node1))
.setSettings(
Settings.builder()
.put("index.routing.allocation.require._name", node1)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))
)
.setMapping("host", "type=keyword")
);
for (int i = 0; i < numDocs1; i++) {
@@ -51,7 +56,11 @@ public void testCollectWarnings() {
client().admin()
.indices()
.prepareCreate("index-2")
.setSettings(Settings.builder().put("index.routing.allocation.require._name", node2))
.setSettings(
Settings.builder()
.put("index.routing.allocation.require._name", node2)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))
)
.setMapping("host", "type=keyword")
);
for (int i = 0; i < numDocs2; i++) {
@@ -16,6 +16,7 @@
import org.elasticsearch.action.search.SearchShardsRequest;
import org.elasticsearch.action.search.SearchShardsResponse;
import org.elasticsearch.action.search.TransportSearchShardsAction;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RefCountingRunnable;
@@ -26,7 +27,6 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.OwningChannelActionListener;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver;
@@ -35,6 +35,7 @@
import org.elasticsearch.compute.operator.ResponseHeadersCollector;
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSink;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.core.IOUtils;
@@ -279,7 +280,7 @@ private ActionListener<Void> cancelOnFailure(CancellableTask task, AtomicBoolean
}

void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener<List<DriverProfile>> listener) {
listener = ActionListener.runAfter(listener, () -> Releasables.close(context.searchContexts));
listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts));
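// Note: runBefore (changed from runAfter) releases this batch's search contexts before the listener fires, so the next batch never overlaps with this batch's open contexts.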
final List<Driver> drivers;
try {
LocalExecutionPlanner planner = new LocalExecutionPlanner(
@@ -500,37 +501,93 @@ public void writeTo(StreamOutput out) throws IOException {
// TODO: Use an internal action here
public static final String DATA_ACTION_NAME = EsqlQueryAction.NAME + "/data";

private class DataNodeRequestExecutor {
private final DataNodeRequest request;
private final CancellableTask parentTask;
private final ExchangeSinkHandler exchangeSink;
private final ActionListener<DataNodeResponse> listener;
private final List<DriverProfile> driverProfiles;
private final int maxConcurrentShards;
private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data

DataNodeRequestExecutor(
DataNodeRequest request,
CancellableTask parentTask,
ExchangeSinkHandler exchangeSink,
int maxConcurrentShards,
ActionListener<DataNodeResponse> listener
) {
this.request = request;
this.parentTask = parentTask;
this.exchangeSink = exchangeSink;
this.listener = listener;
this.driverProfiles = request.configuration().profile() ? Collections.synchronizedList(new ArrayList<>()) : List.of();
this.maxConcurrentShards = maxConcurrentShards;
this.blockingSink = exchangeSink.createExchangeSink();
}

void start() {
parentTask.addListener(
() -> exchangeService.finishSinkHandler(request.sessionId(), new TaskCancelledException(parentTask.getReasonCancelled()))
);
runBatch(0);
}

private void runBatch(int startBatchIndex) {
final EsqlConfiguration configuration = request.configuration();
final var sessionId = request.sessionId();
final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size());
List<ShardId> shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex);
acquireSearchContexts(shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> {
assert ThreadPool.assertCurrentThreadPool(ESQL_THREAD_POOL_NAME, ESQL_WORKER_THREAD_POOL_NAME);
var computeContext = new ComputeContext(sessionId, searchContexts, configuration, null, exchangeSink);
runCompute(
parentTask,
computeContext,
request.plan(),
ActionListener.wrap(profiles -> onBatchCompleted(endBatchIndex, profiles), this::onFailure)
);
}, this::onFailure));
}

private void onBatchCompleted(int lastBatchIndex, List<DriverProfile> batchProfiles) {
if (request.configuration().profile()) {
driverProfiles.addAll(batchProfiles);
}
if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) {
runBatch(lastBatchIndex);
} else {
blockingSink.finish();
// don't return until all pages are fetched
exchangeSink.addCompletionListener(
ContextPreservingActionListener.wrapPreservingContext(
ActionListener.runBefore(
listener.map(nullValue -> new DataNodeResponse(driverProfiles)),
() -> exchangeService.finishSinkHandler(request.sessionId(), null)
),
transportService.getThreadPool().getThreadContext()
)
);
}
}

private void onFailure(Exception e) {
exchangeService.finishSinkHandler(request.sessionId(), e);
listener.onFailure(e);
}
}

private class DataNodeRequestHandler implements TransportRequestHandler<DataNodeRequest> {
@Override
public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
final var parentTask = (CancellableTask) task;
final var sessionId = request.sessionId();
final var exchangeSink = exchangeService.getSinkHandler(sessionId);
parentTask.addListener(() -> exchangeService.finishSinkHandler(sessionId, new TaskCancelledException("task cancelled")));
final ActionListener<DataNodeResponse> listener = new OwningChannelActionListener<>(channel);
final EsqlConfiguration configuration = request.configuration();
acquireSearchContexts(request.shardIds(), configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> {
assert ThreadPool.assertCurrentThreadPool(ESQL_THREAD_POOL_NAME);
var computeContext = new ComputeContext(sessionId, searchContexts, configuration, null, exchangeSink);
runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(driverProfiles -> {
// don't return until all pages are fetched
exchangeSink.addCompletionListener(
ContextPreservingActionListener.wrapPreservingContext(
ActionListener.releaseAfter(
listener.map(nullValue -> new DataNodeResponse(driverProfiles)),
() -> exchangeService.finishSinkHandler(sessionId, null)
),
transportService.getThreadPool().getThreadContext()
)
);
}, e -> {
exchangeService.finishSinkHandler(sessionId, e);
listener.onFailure(e);
}));
}, e -> {
exchangeService.finishSinkHandler(sessionId, e);
listener.onFailure(e);
}));
DataNodeRequestExecutor executor = new DataNodeRequestExecutor(
request,
(CancellableTask) task,
exchangeService.getSinkHandler(request.sessionId()),
request.configuration().pragmas().maxConcurrentShardsPerNode(),
new ChannelActionListener<>(channel)
);
executor.start();
}
}
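To make the batching concrete: with 25 target shards on a node and the
default max_concurrent_shards_per_node of 10, runBatch executes the
sublists [0, 10), [10, 20), and [20, 25) sequentially. Between batches,
onBatchCompleted checks exchangeSink.isFinished(); if the coordinator
has already fetched every page it needs (for example, a small LIMIT
satisfied by the first batch), the remaining shards are never opened.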
