Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.http;

import org.apache.http.client.methods.HttpGet;
import org.elasticsearch.action.admin.indices.recovery.RecoveryAction;
import org.elasticsearch.action.admin.indices.recovery.TransportRecoveryAction;
import org.elasticsearch.action.admin.indices.recovery.TransportRecoveryActionHelper;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.transport.TransportService;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Semaphore;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.not;

public class IndicesRecoveryRestCancellationIT extends HttpSmokeTestCase {

public void testIndicesRecoveryRestCancellation() throws Exception {
runTest(new Request(HttpGet.METHOD_NAME, "/_recovery"));
}

public void testCatRecoveryRestCancellation() throws Exception {
runTest(new Request(HttpGet.METHOD_NAME, "/_cat/recovery"));
}

private void runTest(Request request) throws Exception {

createIndex("test");
ensureGreen("test");

final List<Semaphore> operationBlocks = new ArrayList<>();
for (final TransportRecoveryAction transportRecoveryAction : internalCluster().getInstances(TransportRecoveryAction.class)) {
final Semaphore operationBlock = new Semaphore(1);
operationBlocks.add(operationBlock);
TransportRecoveryActionHelper.setOnShardOperation(transportRecoveryAction, () -> {
try {
operationBlock.acquire();
} catch (InterruptedException e) {
throw new AssertionError(e);
}
operationBlock.release();
});
}
assertThat(operationBlocks, not(empty()));

final List<Releasable> releasables = new ArrayList<>();
try {
for (final Semaphore operationBlock : operationBlocks) {
operationBlock.acquire();
releasables.add(operationBlock::release);
}

final PlainActionFuture<Void> future = new PlainActionFuture<>();
logger.info("--> sending request");
final Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() {
@Override
public void onSuccess(Response response) {
future.onResponse(null);
}

@Override
public void onFailure(Exception exception) {
future.onFailure(exception);
}
});

logger.info("--> waiting for task to start");
assertBusy(() -> {
final List<TaskInfo> tasks = client().admin().cluster().prepareListTasks().get().getTasks();
assertTrue(tasks.toString(), tasks.stream().anyMatch(t -> t.getAction().startsWith(RecoveryAction.NAME)));
});

logger.info("--> waiting for at least one task to hit a block");
assertBusy(() -> assertTrue(operationBlocks.stream().anyMatch(Semaphore::hasQueuedThreads)));

logger.info("--> cancelling request");
cancellable.cancel();
expectThrows(CancellationException.class, future::actionGet);

logger.info("--> checking that all tasks are marked as cancelled");
assertBusy(() -> {
boolean foundTask = false;
for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
for (CancellableTask cancellableTask : transportService.getTaskManager().getCancellableTasks().values()) {
if (cancellableTask.getAction().startsWith(RecoveryAction.NAME)) {
foundTask = true;
assertTrue("task " + cancellableTask.getId() + " not cancelled", cancellableTask.isCancelled());
}
}
}
assertTrue("found no cancellable tasks", foundTask);
});
} finally {
Releasables.close(releasables);
}

logger.info("--> checking that all tasks have finished");
assertBusy(() -> {
final List<TaskInfo> tasks = client().admin().cluster().prepareListTasks().get().getTasks();
assertTrue(tasks.toString(), tasks.stream().noneMatch(t -> t.getAction().startsWith(RecoveryAction.NAME)));
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.util.Map;

/**
* Request for recovery information
Expand Down Expand Up @@ -90,4 +94,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(detailed);
out.writeBoolean(activeOnly);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardsIterator;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -88,6 +90,8 @@ protected RecoveryRequest readRequestFrom(StreamInput in) throws IOException {

@Override
protected RecoveryState shardOperation(RecoveryRequest request, ShardRouting shardRouting, Task task) {
assert task instanceof CancellableTask;
runOnShardOperation();
IndexService indexService = indicesService.indexServiceSafe(shardRouting.shardId().getIndex());
IndexShard indexShard = indexService.getShard(shardRouting.shardId().id());
return indexShard.recoveryState();
Expand All @@ -107,4 +111,19 @@ protected ClusterBlockException checkGlobalBlock(ClusterState state, RecoveryReq
protected ClusterBlockException checkRequestBlock(ClusterState state, RecoveryRequest request, String[] concreteIndices) {
return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices);
}

@Nullable // unless running tests that inject extra behaviour
private volatile Runnable onShardOperation;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is totally stupid, but I just can't see a better way to block these operations for the purposes of this test.


private void runOnShardOperation() {
final Runnable onShardOperation = this.onShardOperation;
if (onShardOperation != null) {
onShardOperation.run();
}
}

// exposed for tests: inject some extra behaviour that runs when shardOperation() is called
void setOnShardOperation(@Nullable Runnable onShardOperation) {
this.onShardOperation = onShardOperation;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestToXContentListener;

import java.io.IOException;
Expand Down Expand Up @@ -50,7 +51,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
recoveryRequest.detailed(request.paramAsBoolean("detailed", false));
recoveryRequest.activeOnly(request.paramAsBoolean("active_only", false));
recoveryRequest.indicesOptions(IndicesOptions.fromRequest(request, recoveryRequest.indicesOptions()));
return channel -> client.admin().indices().recoveries(recoveryRequest, new RestToXContentListener<>(channel));
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel())
.admin().indices().recoveries(recoveryRequest, new RestToXContentListener<>(channel));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestResponseListener;

import java.util.Comparator;
Expand Down Expand Up @@ -63,7 +64,8 @@ public RestChannelConsumer doCatRequest(final RestRequest request, final NodeCli
recoveryRequest.activeOnly(request.paramAsBoolean("active_only", false));
recoveryRequest.indicesOptions(IndicesOptions.fromRequest(request, recoveryRequest.indicesOptions()));

return channel -> client.admin().indices().recoveries(recoveryRequest, new RestResponseListener<RecoveryResponse>(channel) {
return channel -> new RestCancellableNodeClient(client, request.getHttpChannel())
.admin().indices().recoveries(recoveryRequest, new RestResponseListener<RecoveryResponse>(channel) {
@Override
public RestResponse buildResponse(final RecoveryResponse response) throws Exception {
return RestTable.buildResponse(buildRecoveryTable(request, response), channel);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.action.admin.indices.recovery;

/**
* Helper methods for {@link TransportRecoveryAction}.
*/
public class TransportRecoveryActionHelper {

/**
* Helper method for tests to call {@link TransportRecoveryAction#setOnShardOperation}.
*/
public static void setOnShardOperation(TransportRecoveryAction transportRecoveryAction, Runnable setOnShardOperation) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Maybe just inline this into the test class that actually uses it? Why have a separate method for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put it here so I could keep TransportRecoveryAction#setOnShardOperation package-private (without needing the test itself to be in the same package as the action)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok, makes sense :)

transportRecoveryAction.setOnShardOperation(setOnShardOperation);
}
}