Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESIntegTestCase;

import java.io.IOException;
Expand Down Expand Up @@ -131,9 +132,10 @@ public void testLocalClusterAlias() {
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
IndexResponse indexResponse = client().index(indexRequest).actionGet();
assertEquals(RestStatus.CREATED, indexResponse.status());
TaskId parentTaskId = new TaskId("node", randomNonNegativeLong());

{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY,
"local", nowInMillis, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
Expand All @@ -145,7 +147,7 @@ public void testLocalClusterAlias() {
assertEquals("1", hit.getId());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY,
"", nowInMillis, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
Expand All @@ -159,6 +161,7 @@ public void testLocalClusterAlias() {
}

public void testAbsoluteStartMillis() {
TaskId parentTaskId = new TaskId("node", randomNonNegativeLong());
{
IndexRequest indexRequest = new IndexRequest("test-1970.01.01");
indexRequest.id("1");
Expand Down Expand Up @@ -187,21 +190,21 @@ public void testAbsoluteStartMillis() {
assertEquals(0, searchResponse.getTotalShards());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
searchRequest.indices("<test-{now/d}>");
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(1, searchResponse.getHits().getTotalHits().value);
assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date");
Expand All @@ -217,6 +220,7 @@ public void testAbsoluteStartMillis() {

public void testFinalReduce() {
long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
TaskId taskId = new TaskId("node", randomNonNegativeLong());
{
IndexRequest indexRequest = new IndexRequest("test");
indexRequest.id("1");
Expand All @@ -243,7 +247,7 @@ public void testFinalReduce() {
source.aggregation(terms);

{
SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest,
SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(taskId, originalRequest,
Strings.EMPTY_ARRAY, "remote", nowInMillis, true);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Expand All @@ -252,7 +256,7 @@ public void testFinalReduce() {
assertEquals(1, longTerms.getBuckets().size());
}
{
SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest,
SearchRequest searchRequest = SearchRequest.subSearchRequest(taskId, originalRequest,
Strings.EMPTY_ARRAY, "remote", nowInMillis, false);
SearchResponse searchResponse = client().search(searchRequest).actionGet();
assertEquals(2, searchResponse.getHits().getTotalHits().value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

package org.elasticsearch.search.ccs;

import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture;
Expand All @@ -27,6 +31,7 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.query.MatchAllQueryBuilder;
Expand All @@ -36,11 +41,13 @@
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.AbstractMultiClustersTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.test.NodeRoles;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.transport.TransportService;
import org.hamcrest.Matchers;
import org.junit.Before;

import java.util.Collection;
Expand Down Expand Up @@ -146,6 +153,70 @@ public void testProxyConnectionDisconnect() throws Exception {
}
}

public void testCancel() throws Exception {
assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo"));
indexDocs(client(LOCAL_CLUSTER), "demo");
final InternalTestCluster remoteCluster = cluster("cluster_a");
remoteCluster.ensureAtLeastNumDataNodes(1);
final Settings.Builder allocationFilter = Settings.builder();
if (randomBoolean()) {
remoteCluster.ensureAtLeastNumDataNodes(3);
List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
.filter(DiscoveryNode::isDataNode)
.map(DiscoveryNode::getName)
.collect(Collectors.toList());
assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(3));
List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
disconnectFromRemoteClusters();
configureRemoteCluster("cluster_a", seedNodes);
if (randomBoolean()) {
// Using proxy connections
allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
} else {
allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
}
}
assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
.setSettings(Settings.builder().put(allocationFilter.build()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)));
assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
.setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
indexDocs(client("cluster_a"), "prod");
SearchListenerPlugin.blockQueryPhase();
PlainActionFuture<SearchResponse> queryFuture = new PlainActionFuture<>();
SearchRequest searchRequest = new SearchRequest("demo", "cluster_a:prod");
searchRequest.allowPartialSearchResults(false);
searchRequest.setCcsMinimizeRoundtrips(false);
searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1000));
client(LOCAL_CLUSTER).search(searchRequest, queryFuture);
SearchListenerPlugin.waitSearchStarted();
// Get the search task and cancelled
final TaskInfo rootTask = client().admin().cluster().prepareListTasks()
.setActions(SearchAction.INSTANCE.name())
.get().getTasks().stream().filter(t -> t.getParentTaskId().isSet() == false)
.findFirst().get();
final CancelTasksRequest cancelRequest = new CancelTasksRequest().setTaskId(rootTask.getTaskId());
cancelRequest.setWaitForCompletion(randomBoolean());
final ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().cancelTasks(cancelRequest);
assertBusy(() -> {
final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
for (TransportService transportService : transportServices) {
Collection<CancellableTask> cancellableTasks = transportService.getTaskManager().getCancellableTasks().values();
for (CancellableTask cancellableTask : cancellableTasks) {
assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled());
}
}
});
SearchListenerPlugin.allowQueryPhase();
assertBusy(() -> assertTrue(queryFuture.isDone()));
assertBusy(() -> assertTrue(cancelFuture.isDone()));
assertBusy(() -> {
final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
for (TransportService transportService : transportServices) {
assertThat(transportService.getTaskManager().getBannedTaskIds(), Matchers.empty());
}
});
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
if (clusterAlias.equals(LOCAL_CLUSTER)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,25 @@ public SearchRequest(String[] indices, SearchSourceBuilder source) {
* Used when a {@link SearchRequest} is created and executed as part of a cross-cluster search request
* performing reduction on each cluster in order to minimize network round-trips between the coordinating node and the remote clusters.
*
* @param parentTaskId the parent taskId of the original search request
* @param originalSearchRequest the original search request
* @param indices the indices to search against
* @param clusterAlias the alias to prefix index names with in the returned search results
* @param absoluteStartMillis the absolute start time to be used on the remote clusters to ensure that the same value is used
* @param finalReduce whether the reduction should be final or not
*/
static SearchRequest subSearchRequest(SearchRequest originalSearchRequest, String[] indices,
static SearchRequest subSearchRequest(TaskId parentTaskId, SearchRequest originalSearchRequest, String[] indices,
String clusterAlias, long absoluteStartMillis, boolean finalReduce) {
Objects.requireNonNull(parentTaskId, "parentTaskId must be specified");
Objects.requireNonNull(originalSearchRequest, "search request must not be null");
validateIndices(indices);
Objects.requireNonNull(clusterAlias, "cluster alias must not be null");
if (absoluteStartMillis < 0) {
throw new IllegalArgumentException("absoluteStartMillis must not be negative but was [" + absoluteStartMillis + "]");
}
return new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce);
final SearchRequest request = new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce);
request.setParentTask(parentTaskId);
return request;
}

private SearchRequest(SearchRequest searchRequest, String[] indices, String localClusterAlias, long absoluteStartMillis,
Expand Down Expand Up @@ -304,7 +308,7 @@ boolean isFinalReduce() {
/**
* Returns the current time in milliseconds from the time epoch, to be used for the execution of this search request. Used to
* ensure that the same value, determined by the coordinating node, is used on all nodes involved in the execution of the search
* request. When created through {@link #subSearchRequest(SearchRequest, String[], String, long, boolean)}, this method returns
* request. When created through {@link #subSearchRequest(TaskId, SearchRequest, String[], String, long, boolean)}, this method returns
* the provided current time, otherwise it will return {@link System#currentTimeMillis()}.
*/
long getOrCreateAbsoluteStartMillis() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
import org.elasticsearch.search.profile.ProfileShardResult;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterAware;
import org.elasticsearch.transport.RemoteClusterService;
Expand Down Expand Up @@ -295,7 +296,8 @@ private void executeRequest(Task task, SearchRequest searchRequest,
task, timeProvider, searchRequest, localIndices, clusterState, listener, searchContext, searchAsyncActionProvider);
} else {
if (shouldMinimizeRoundtrips(searchRequest)) {
ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider,
final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).getTaskId();
ccsRemoteReduce(parentTaskId, searchRequest, localIndices, remoteClusterIndices, timeProvider,
searchService.aggReduceContextBuilder(searchRequest),
remoteClusterService, threadPool, listener,
(r, l) -> executeLocalSearch(
Expand Down Expand Up @@ -357,8 +359,9 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) {
source.collapse().getInnerHits().isEmpty();
}

static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices,
SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
static void ccsRemoteReduce(TaskId parentTaskId, SearchRequest searchRequest, OriginalIndices localIndices,
Map<String, OriginalIndices> remoteIndices, SearchTimeProvider timeProvider,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener,
BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {

Expand All @@ -369,7 +372,7 @@ static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIn
String clusterAlias = entry.getKey();
boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
OriginalIndices indices = entry.getValue();
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(),
clusterAlias, timeProvider.getAbsoluteStartMillis(), true);
Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
remoteClusterClient.search(ccsSearchRequest, new ActionListener<SearchResponse>() {
Expand Down Expand Up @@ -407,7 +410,7 @@ public void onFailure(Exception e) {
String clusterAlias = entry.getKey();
boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
OriginalIndices indices = entry.getValue();
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(),
clusterAlias, timeProvider.getAbsoluteStartMillis(), false);
ActionListener<SearchResponse> ccsListener = createCCSListener(clusterAlias, skipUnavailable, countDown,
skippedClusters, exceptions, searchResponseMerger, totalClusters, listener);
Expand All @@ -417,7 +420,7 @@ public void onFailure(Exception e) {
if (localIndices != null) {
ActionListener<SearchResponse> ccsListener = createCCSListener(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
false, countDown, skippedClusters, exceptions, searchResponseMerger, totalClusters, listener);
SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(searchRequest, localIndices.indices(),
SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, localIndices.indices(),
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false);
localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ private void setBanOnChildConnections(String reason, boolean waitForCompletion,
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
for (Transport.Connection connection : childConnections) {
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, banRequest, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
Expand All @@ -167,6 +168,7 @@ private void removeBanOnChildConnections(CancellableTask task, Collection<Transp
final BanParentTaskRequest request =
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
for (Transport.Connection connection : childConnections) {
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, request, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
Expand Down
Loading