Skip to content

Commit c3da66d

Browse files
authored
Implement adaptive replica selection (#26128)
* Implement adaptive replica selection This implements the selection algorithm described in the C3 paper for determining which copy of the data a query should be routed to. By using the service time EWMA, response time EWMA, and queue size EWMA we calculate the score of a node by piggybacking these metrics with each search request. Since Elasticsearch lacks the "broadcast to every copy" behavior that Cassandra has (as mentioned in the C3 paper) to update metrics after a node has been highly weighted, this implementation adjusts a node's response stats using the average of the its own and the "best" node's metrics. This is so that a long GC or other activity that may cause a node's rank to increase dramatically does not permanently keep a node from having requests routed to it, instead it will eventually lower its score back to the realm where it is a potential candidate for new queries. This feature is off by default and can be turned on with the dynamic setting `cluster.routing.use_adaptive_replica_selection`. Relates to #24915, however instead of `b=3` I used `b=4` (after benchmarking) * Randomly use adaptive replica selection for internal test cluster * Use an action name *prefix* for retrieving pending requests * Add unit test for replica selection * don't use adaptive replica selection in SearchPreferenceIT * Track client connections in a SearchTransportService instead of TransportService * Bind `entry` pieces in local variables * Add javadoc link to C3 paper and javadocs for stat adjustments * Bind entry's key and value to local variables * Remove unneeded actionNamePrefix parameter * Use conns.longValue() instead of cached Long * Add comments about removing entries from the map * Pull out bindings for `entry` in IndexShardRoutingTable * Use .compareTo instead of manually comparing * add assert for connections not being null and gte to 1 * Copy map for pending search connections instead of "live" map * Increase the number of pending search requests used for calculating rank when chosen When a node gets chosen, this increases the number of search counts for the winning node so that it will not be as likely to be chosen again for non-concurrent search requests. * Remove unused HashMap import * Rename rank -> rankShardsAndUpdateStats * Rename rankedActiveInitializingShardsIt -> activeInitializingShardsRankedIt * Instead of precalculating winning node, use "winning" shard from ranked list * Sort null ranked nodes before nodes that have a rank
1 parent 432f162 commit c3da66d

File tree

16 files changed

+458
-45
lines changed

16 files changed

+458
-45
lines changed

core/src/main/java/org/elasticsearch/action/search/SearchExecutionStatsCollector.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public void onResponse(SearchPhaseResult response) {
6161
final int queueSize = queryResult.nodeQueueSize();
6262
final long responseDuration = System.nanoTime() - startNanos;
6363
// EWMA/queue size may be -1 if the query node doesn't support capturing it
64-
if (serviceTimeEWMA > 0 && queueSize > 0) {
64+
if (serviceTimeEWMA > 0 && queueSize >= 0) {
6565
collector.addNodeStatistics(nodeId, queueSize, responseDuration, serviceTimeEWMA);
6666
}
6767
}

core/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

+68-9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.common.io.stream.StreamInput;
3131
import org.elasticsearch.common.io.stream.StreamOutput;
3232
import org.elasticsearch.common.settings.Settings;
33+
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
3334
import org.elasticsearch.search.SearchPhaseResult;
3435
import org.elasticsearch.search.SearchService;
3536
import org.elasticsearch.search.dfs.DfsSearchResult;
@@ -50,13 +51,17 @@
5051
import org.elasticsearch.transport.TransportActionProxy;
5152
import org.elasticsearch.transport.TaskAwareTransportRequestHandler;
5253
import org.elasticsearch.transport.TransportChannel;
54+
import org.elasticsearch.transport.TransportException;
5355
import org.elasticsearch.transport.TransportRequest;
5456
import org.elasticsearch.transport.TransportRequestOptions;
5557
import org.elasticsearch.transport.TransportResponse;
5658
import org.elasticsearch.transport.TransportService;
5759

5860
import java.io.IOException;
5961
import java.io.UncheckedIOException;
62+
import java.util.Collections;
63+
import java.util.HashMap;
64+
import java.util.Map;
6065
import java.util.function.BiFunction;
6166
import java.util.function.Supplier;
6267

@@ -80,6 +85,7 @@ public class SearchTransportService extends AbstractComponent {
8085

8186
private final TransportService transportService;
8287
private final BiFunction<Transport.Connection, SearchActionListener, ActionListener> responseWrapper;
88+
private final Map<String, Long> clientConnections = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
8389

8490
public SearchTransportService(Settings settings, TransportService transportService,
8591
BiFunction<Transport.Connection, SearchActionListener, ActionListener> responseWrapper) {
@@ -131,7 +137,7 @@ public void sendClearAllScrollContexts(Transport.Connection connection, final Ac
131137
public void sendExecuteDfs(Transport.Connection connection, final ShardSearchTransportRequest request, SearchTask task,
132138
final SearchActionListener<DfsSearchResult> listener) {
133139
transportService.sendChildRequest(connection, DFS_ACTION_NAME, request, task,
134-
new ActionListenerResponseHandler<>(listener, DfsSearchResult::new));
140+
new ConnectionCountingHandler<>(listener, DfsSearchResult::new, clientConnections, connection.getNode().getId()));
135141
}
136142

137143
public void sendExecuteQuery(Transport.Connection connection, final ShardSearchTransportRequest request, SearchTask task,
@@ -143,25 +149,26 @@ public void sendExecuteQuery(Transport.Connection connection, final ShardSearchT
143149

144150
final ActionListener handler = responseWrapper.apply(connection, listener);
145151
transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task,
146-
new ActionListenerResponseHandler<>(handler, supplier));
152+
new ConnectionCountingHandler<>(handler, supplier, clientConnections, connection.getNode().getId()));
147153
}
148154

149155
public void sendExecuteQuery(Transport.Connection connection, final QuerySearchRequest request, SearchTask task,
150156
final SearchActionListener<QuerySearchResult> listener) {
151157
transportService.sendChildRequest(connection, QUERY_ID_ACTION_NAME, request, task,
152-
new ActionListenerResponseHandler<>(listener, QuerySearchResult::new));
158+
new ConnectionCountingHandler<>(listener, QuerySearchResult::new, clientConnections, connection.getNode().getId()));
153159
}
154160

155161
public void sendExecuteScrollQuery(Transport.Connection connection, final InternalScrollSearchRequest request, SearchTask task,
156162
final SearchActionListener<ScrollQuerySearchResult> listener) {
157163
transportService.sendChildRequest(connection, QUERY_SCROLL_ACTION_NAME, request, task,
158-
new ActionListenerResponseHandler<>(listener, ScrollQuerySearchResult::new));
164+
new ConnectionCountingHandler<>(listener, ScrollQuerySearchResult::new, clientConnections, connection.getNode().getId()));
159165
}
160166

161167
public void sendExecuteScrollFetch(Transport.Connection connection, final InternalScrollSearchRequest request, SearchTask task,
162168
final SearchActionListener<ScrollQueryFetchSearchResult> listener) {
163169
transportService.sendChildRequest(connection, QUERY_FETCH_SCROLL_ACTION_NAME, request, task,
164-
new ActionListenerResponseHandler<>(listener, ScrollQueryFetchSearchResult::new));
170+
new ConnectionCountingHandler<>(listener, ScrollQueryFetchSearchResult::new,
171+
clientConnections, connection.getNode().getId()));
165172
}
166173

167174
public void sendExecuteFetch(Transport.Connection connection, final ShardFetchSearchRequest request, SearchTask task,
@@ -177,22 +184,31 @@ public void sendExecuteFetchScroll(Transport.Connection connection, final ShardF
177184
private void sendExecuteFetch(Transport.Connection connection, String action, final ShardFetchRequest request, SearchTask task,
178185
final SearchActionListener<FetchSearchResult> listener) {
179186
transportService.sendChildRequest(connection, action, request, task,
180-
new ActionListenerResponseHandler<>(listener, FetchSearchResult::new));
187+
new ConnectionCountingHandler<>(listener, FetchSearchResult::new, clientConnections, connection.getNode().getId()));
181188
}
182189

183190
/**
184191
* Used by {@link TransportSearchAction} to send the expand queries (field collapsing).
185192
*/
186193
void sendExecuteMultiSearch(final MultiSearchRequest request, SearchTask task,
187-
final ActionListener<MultiSearchResponse> listener) {
188-
transportService.sendChildRequest(transportService.getConnection(transportService.getLocalNode()), MultiSearchAction.NAME, request,
189-
task, new ActionListenerResponseHandler<>(listener, MultiSearchResponse::new));
194+
final ActionListener<MultiSearchResponse> listener) {
195+
final Transport.Connection connection = transportService.getConnection(transportService.getLocalNode());
196+
transportService.sendChildRequest(connection, MultiSearchAction.NAME, request, task,
197+
new ConnectionCountingHandler<>(listener, MultiSearchResponse::new, clientConnections, connection.getNode().getId()));
190198
}
191199

192200
public RemoteClusterService getRemoteClusterService() {
193201
return transportService.getRemoteClusterService();
194202
}
195203

204+
/**
205+
* Return a map of nodeId to pending number of search requests.
206+
* This is a snapshot of the current pending search and not a live map.
207+
*/
208+
public Map<String, Long> getPendingSearchRequests() {
209+
return new HashMap<>(clientConnections);
210+
}
211+
196212
static class ScrollFreeContextRequest extends TransportRequest {
197213
private long id;
198214

@@ -486,4 +502,47 @@ Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
486502
return transportService.getRemoteClusterService().getConnection(node, clusterAlias);
487503
}
488504
}
505+
506+
final class ConnectionCountingHandler<Response extends TransportResponse> extends ActionListenerResponseHandler<Response> {
507+
private final Map<String, Long> clientConnections;
508+
private final String nodeId;
509+
510+
ConnectionCountingHandler(final ActionListener<? super Response> listener, final Supplier<Response> responseSupplier,
511+
final Map<String, Long> clientConnections, final String nodeId) {
512+
super(listener, responseSupplier);
513+
this.clientConnections = clientConnections;
514+
this.nodeId = nodeId;
515+
// Increment the number of connections for this node by one
516+
clientConnections.compute(nodeId, (id, conns) -> conns == null ? 1 : conns + 1);
517+
}
518+
519+
@Override
520+
public void handleResponse(Response response) {
521+
super.handleResponse(response);
522+
// Decrement the number of connections or remove it entirely if there are no more connections
523+
// We need to remove the entry here so we don't leak when nodes go away forever
524+
assert assertNodePresent();
525+
clientConnections.computeIfPresent(nodeId, (id, conns) -> conns.longValue() == 1 ? null : conns - 1);
526+
}
527+
528+
@Override
529+
public void handleException(TransportException e) {
530+
super.handleException(e);
531+
// Decrement the number of connections or remove it entirely if there are no more connections
532+
// We need to remove the entry here so we don't leak when nodes go away forever
533+
assert assertNodePresent();
534+
clientConnections.computeIfPresent(nodeId, (id, conns) -> conns.longValue() == 1 ? null : conns - 1);
535+
}
536+
537+
private boolean assertNodePresent() {
538+
clientConnections.compute(nodeId, (id, conns) -> {
539+
assert conns != null : "number of connections for " + id + " is null, but should be an integer";
540+
assert conns >= 1 : "number of connections for " + id + " should be >= 1 but was " + conns;
541+
return conns;
542+
});
543+
// Always return true, there is additional asserting here, the boolean is just so this
544+
// can be skipped when assertions are not enabled
545+
return true;
546+
}
547+
}
489548
}

core/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,9 @@ private void executeSearch(SearchTask task, SearchTimeProvider timeProvider, Sea
284284
for (int i = 0; i < indices.length; i++) {
285285
concreteIndices[i] = indices[i].getName();
286286
}
287+
Map<String, Long> nodeSearchCounts = searchTransportService.getPendingSearchRequests();
287288
GroupShardsIterator<ShardIterator> localShardsIterator = clusterService.operationRouting().searchShards(clusterState,
288-
concreteIndices, routingMap, searchRequest.preference());
289+
concreteIndices, routingMap, searchRequest.preference(), searchService.getResponseCollectorService(), nodeSearchCounts);
289290
GroupShardsIterator<SearchShardIterator> shardIterators = mergeShardsIterators(localShardsIterator, localIndices,
290291
remoteShardIterators);
291292

core/src/main/java/org/elasticsearch/cluster/routing/IndexShardRoutingTable.java

+165
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,24 @@
2929
import org.elasticsearch.common.util.set.Sets;
3030
import org.elasticsearch.index.Index;
3131
import org.elasticsearch.index.shard.ShardId;
32+
import org.elasticsearch.node.ResponseCollectorService;
3233

3334
import java.io.IOException;
3435
import java.util.ArrayList;
3536
import java.util.Arrays;
3637
import java.util.Collections;
38+
import java.util.Comparator;
39+
import java.util.HashMap;
3740
import java.util.HashSet;
3841
import java.util.Iterator;
3942
import java.util.LinkedList;
4043
import java.util.List;
4144
import java.util.Locale;
4245
import java.util.Map;
46+
import java.util.Optional;
47+
import java.util.OptionalDouble;
4348
import java.util.Set;
49+
import java.util.stream.Collectors;
4450

4551
import static java.util.Collections.emptyMap;
4652

@@ -261,6 +267,165 @@ public ShardIterator activeInitializingShardsIt(int seed) {
261267
return new PlainShardIterator(shardId, ordered);
262268
}
263269

270+
/**
271+
* Returns an iterator over active and initializing shards, ordered by the adaptive replica
272+
* selection forumla. Making sure though that its random within the active shards of the same
273+
* (or missing) rank, and initializing shards are the last to iterate through.
274+
*/
275+
public ShardIterator activeInitializingShardsRankedIt(@Nullable ResponseCollectorService collector,
276+
@Nullable Map<String, Long> nodeSearchCounts) {
277+
final int seed = shuffler.nextSeed();
278+
if (allInitializingShards.isEmpty()) {
279+
return new PlainShardIterator(shardId,
280+
rankShardsAndUpdateStats(shuffler.shuffle(activeShards, seed), collector, nodeSearchCounts));
281+
}
282+
283+
ArrayList<ShardRouting> ordered = new ArrayList<>(activeShards.size() + allInitializingShards.size());
284+
List<ShardRouting> rankedActiveShards =
285+
rankShardsAndUpdateStats(shuffler.shuffle(activeShards, seed), collector, nodeSearchCounts);
286+
ordered.addAll(rankedActiveShards);
287+
List<ShardRouting> rankedInitializingShards =
288+
rankShardsAndUpdateStats(allInitializingShards, collector, nodeSearchCounts);
289+
ordered.addAll(rankedInitializingShards);
290+
return new PlainShardIterator(shardId, ordered);
291+
}
292+
293+
private static Set<String> getAllNodeIds(final List<ShardRouting> shards) {
294+
final Set<String> nodeIds = new HashSet<>();
295+
for (ShardRouting shard : shards) {
296+
nodeIds.add(shard.currentNodeId());
297+
}
298+
return nodeIds;
299+
}
300+
301+
private static Map<String, Optional<ResponseCollectorService.ComputedNodeStats>>
302+
getNodeStats(final Set<String> nodeIds, final ResponseCollectorService collector) {
303+
304+
final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats = new HashMap<>(nodeIds.size());
305+
for (String nodeId : nodeIds) {
306+
nodeStats.put(nodeId, collector.getNodeStatistics(nodeId));
307+
}
308+
return nodeStats;
309+
}
310+
311+
private static Map<String, Double> rankNodes(final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats,
312+
final Map<String, Long> nodeSearchCounts) {
313+
final Map<String, Double> nodeRanks = new HashMap<>(nodeStats.size());
314+
for (Map.Entry<String, Optional<ResponseCollectorService.ComputedNodeStats>> entry : nodeStats.entrySet()) {
315+
Optional<ResponseCollectorService.ComputedNodeStats> maybeStats = entry.getValue();
316+
maybeStats.ifPresent(stats -> {
317+
final String nodeId = entry.getKey();
318+
nodeRanks.put(nodeId, stats.rank(nodeSearchCounts.getOrDefault(nodeId, 1L)));
319+
});
320+
}
321+
return nodeRanks;
322+
}
323+
324+
/**
325+
* Adjust the for all other nodes' collected stats. In the original ranking paper there is no need to adjust other nodes' stats because
326+
* Cassandra sends occasional requests to all copies of the data, so their stats will be updated during that broadcast phase. In
327+
* Elasticsearch, however, we do not have that sort of broadcast-to-all behavior. In order to prevent a node that gets a high score and
328+
* then never gets any more requests, we must ensure it eventually returns to a more normal score and can be a candidate for serving
329+
* requests.
330+
*
331+
* This adjustment takes the "winning" node's statistics and adds the average of those statistics with each non-winning node. Let's say
332+
* the winning node had a queue size of 10 and a non-winning node had a queue of 18. The average queue size is (10 + 18) / 2 = 14 so the
333+
* non-winning node will have statistics added for a queue size of 14. This is repeated for the response time and service times as well.
334+
*/
335+
private static void adjustStats(final ResponseCollectorService collector,
336+
final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats,
337+
final String minNodeId,
338+
final ResponseCollectorService.ComputedNodeStats minStats) {
339+
if (minNodeId != null) {
340+
for (Map.Entry<String, Optional<ResponseCollectorService.ComputedNodeStats>> entry : nodeStats.entrySet()) {
341+
final String nodeId = entry.getKey();
342+
final Optional<ResponseCollectorService.ComputedNodeStats> maybeStats = entry.getValue();
343+
if (nodeId.equals(minNodeId) == false && maybeStats.isPresent()) {
344+
final ResponseCollectorService.ComputedNodeStats stats = maybeStats.get();
345+
final int updatedQueue = (minStats.queueSize + stats.queueSize) / 2;
346+
final long updatedResponse = (long) (minStats.responseTime + stats.responseTime) / 2;
347+
final long updatedService = (long) (minStats.serviceTime + stats.serviceTime) / 2;
348+
collector.addNodeStatistics(nodeId, updatedQueue, updatedResponse, updatedService);
349+
}
350+
}
351+
}
352+
}
353+
354+
private static List<ShardRouting> rankShardsAndUpdateStats(List<ShardRouting> shards, final ResponseCollectorService collector,
355+
final Map<String, Long> nodeSearchCounts) {
356+
if (collector == null || nodeSearchCounts == null || shards.size() <= 1) {
357+
return shards;
358+
}
359+
360+
// Retrieve which nodes we can potentially send the query to
361+
final Set<String> nodeIds = getAllNodeIds(shards);
362+
final int nodeCount = nodeIds.size();
363+
364+
final Map<String, Optional<ResponseCollectorService.ComputedNodeStats>> nodeStats = getNodeStats(nodeIds, collector);
365+
366+
// Retrieve all the nodes the shards exist on
367+
final Map<String, Double> nodeRanks = rankNodes(nodeStats, nodeSearchCounts);
368+
369+
// sort all shards based on the shard rank
370+
ArrayList<ShardRouting> sortedShards = new ArrayList<>(shards);
371+
Collections.sort(sortedShards, new NodeRankComparator(nodeRanks));
372+
373+
// adjust the non-winner nodes' stats so they will get a chance to receive queries
374+
if (sortedShards.size() > 1) {
375+
ShardRouting minShard = sortedShards.get(0);
376+
// If the winning shard is not started we are ranking initializing
377+
// shards, don't bother to do adjustments
378+
if (minShard.started()) {
379+
String minNodeId = minShard.currentNodeId();
380+
Optional<ResponseCollectorService.ComputedNodeStats> maybeMinStats = nodeStats.get(minNodeId);
381+
if (maybeMinStats.isPresent()) {
382+
adjustStats(collector, nodeStats, minNodeId, maybeMinStats.get());
383+
// Increase the number of searches for the "winning" node by one.
384+
// Note that this doesn't actually affect the "real" counts, instead
385+
// it only affects the captured node search counts, which is
386+
// captured once for each query in TransportSearchAction
387+
nodeSearchCounts.compute(minNodeId, (id, conns) -> conns == null ? 1 : conns + 1);
388+
}
389+
}
390+
}
391+
392+
return sortedShards;
393+
}
394+
395+
private static class NodeRankComparator implements Comparator<ShardRouting> {
396+
private final Map<String, Double> nodeRanks;
397+
398+
NodeRankComparator(Map<String, Double> nodeRanks) {
399+
this.nodeRanks = nodeRanks;
400+
}
401+
402+
@Override
403+
public int compare(ShardRouting s1, ShardRouting s2) {
404+
if (s1.currentNodeId().equals(s2.currentNodeId())) {
405+
// these shards on the the same node
406+
return 0;
407+
}
408+
Double shard1rank = nodeRanks.get(s1.currentNodeId());
409+
Double shard2rank = nodeRanks.get(s2.currentNodeId());
410+
if (shard1rank != null) {
411+
if (shard2rank != null) {
412+
return shard1rank.compareTo(shard2rank);
413+
} else {
414+
// place non-nulls after null values
415+
return 1;
416+
}
417+
} else {
418+
if (shard2rank != null) {
419+
// place nulls before non-null values
420+
return -1;
421+
} else {
422+
// Both nodes do not have stats, they are equal
423+
return 0;
424+
}
425+
}
426+
}
427+
}
428+
264429
/**
265430
* Returns true if no primaries are active or initializing for this shard
266431
*/

0 commit comments

Comments
 (0)