Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,81 @@

package org.elasticsearch.xpack.profiler;

import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.logging.log4j.LogManager;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Cancellable;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.search.lookup.LeafStoredFieldsLookup;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.netty4.Netty4Plugin;
import org.elasticsearch.xcontent.XContentType;
import org.junit.Before;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

import static org.elasticsearch.action.support.ActionTestUtils.wrapAsRestResponseListener;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.instanceOf;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 1)
public class GetProfilingActionIT extends ESIntegTestCase {
/**
 * Lists the plugins installed on each node of the test cluster.
 *
 * <p>Besides the profiling plugin itself this registers {@link ScriptedBlockPlugin}, which supplies
 * the blocking mock script used by the cancellation test, and the test transport plugin.
 *
 * @return the plugin classes to load on every node
 */
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
    // A stale single-plugin return used to precede this one, which made this statement unreachable.
    return List.of(ProfilingPlugin.class, ScriptedBlockPlugin.class, getTestTransportPlugin());
}

/**
 * Assembles the settings for a single test node.
 *
 * <p>The settings produced by the superclass are used as the base; on top of them profiling is
 * switched on and both the transport and HTTP layers are set to the Netty4 implementations.
 *
 * @param nodeOrdinal   ordinal of the node, forwarded unchanged to the superclass
 * @param otherSettings additional settings, forwarded unchanged to the superclass
 * @return the combined node settings
 */
@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
    Settings.Builder settings = Settings.builder();
    // Base settings first, so the test-specific keys below are applied on top of them.
    settings.put(super.nodeSettings(nodeOrdinal, otherSettings));
    settings.put(ProfilingPlugin.PROFILING_ENABLED.getKey(), true);
    settings.put(NetworkModule.TRANSPORT_TYPE_KEY, Netty4Plugin.NETTY_TRANSPORT_NAME);
    settings.put(NetworkModule.HTTP_TYPE_KEY, Netty4Plugin.NETTY_HTTP_TRANSPORT_NAME);
    return settings.build();
}

/**
 * Opts out of the mock HTTP transport so that HTTP is enabled on the test nodes; the cancellation
 * test in this class sends its request through the REST client.
 *
 * @return always {@code false}
 */
@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}

/**
 * NOTE(review): presumably keeps this test from being run against an externally provided cluster;
 * confirm against the ESIntegTestCase contract.
 *
 * @return always {@code true}
 */
@Override
protected boolean ignoreExternalCluster() {
return true;
}

/**
 * Reads a classpath resource completely into memory.
 *
 * @param resource name of the resource, looked up through the class loader of {@code GetProfilingAction}
 * @return the full contents of the resource
 * @throws IOException if the resource does not exist or cannot be read
 */
private byte[] read(String resource) throws IOException {
    // try-with-resources guarantees the stream is closed even when reading fails.
    try (var in = GetProfilingAction.class.getClassLoader().getResourceAsStream(resource)) {
        if (in == null) {
            // getResourceAsStream signals a missing resource with null; fail with a message that names it.
            throw new IOException("Resource [" + resource + "] not found on the classpath");
        }
        return in.readAllBytes();
    }
}
Expand Down Expand Up @@ -104,4 +151,137 @@ public void testGetProfilingDataUnfiltered() throws Exception {
assertNotNull(response.getExecutables());
assertNotNull("libc.so.6", response.getExecutables().get("QCCDqjSg3bMK1C4YRK6Tiw"));
}

/**
 * Issues a stack traces request whose query filter runs the blocking mock script and then checks,
 * via {@link #verifyCancellation}, that cancelling the REST request cancels the profiling task.
 */
public void testAutomaticCancellation() throws Exception {
    // The "search_block" mock script is provided by ScriptedBlockPlugin.
    String blockingQuery = """
        {
        "sample_size": 10000,
        "query": {
        "bool": {
        "filter": [
        {
        "script": {
        "script": {
        "lang": "mockscript",
        "source": "search_block",
        "params": {}
        }
        }
        }
        ]
        }
        }
        }
        """;
    Request stackTracesRequest = new Request("POST", "/_profiling/stacktraces");
    stackTracesRequest.setEntity(new StringEntity(blockingQuery, ContentType.APPLICATION_JSON.withCharset(StandardCharsets.UTF_8)));
    verifyCancellation(GetProfilingAction.NAME, stackTracesRequest);
}

/**
 * Sends {@code restRequest} asynchronously, waits until the blocking script has been hit, cancels
 * the request and verifies that the task running {@code action} is cancelled as a result.
 *
 * @param action      name of the action whose task is expected to be cancelled
 * @param restRequest the REST request to send and then cancel
 */
void verifyCancellation(String action, Request restRequest) throws Exception {
    final Map<String, String> namesByNodeId = readNodesInfo();
    final List<ScriptedBlockPlugin> blockingPlugins = initBlockFactory();

    final PlainActionFuture<Response> responseFuture = PlainActionFuture.newFuture();
    final Cancellable inFlight = getRestClient().performRequestAsync(restRequest, wrapAsRestResponseListener(responseFuture));

    // Cancel only once the script has been hit at least once.
    awaitForBlock(blockingPlugins);
    inFlight.cancel();
    ensureTaskIsCancelled(action, nodeId -> namesByNodeId.get(nodeId));

    // Release the blocked scripts before checking how the client-side future completed.
    disableBlocks(blockingPlugins);
    expectThrows(CancellationException.class, () -> responseFuture.actionGet());
}

/**
 * Queries the cluster for node information and builds a lookup from node id to node name.
 * The test fails if the nodes info response reports failures.
 *
 * @return a mutable map keyed by node id with the node name as value
 */
private static Map<String, String> readNodesInfo() {
    NodesInfoResponse response = client().admin().cluster().prepareNodesInfo().get();
    assertFalse(response.hasFailures());
    Map<String, String> namesById = new HashMap<>();
    response.getNodes().forEach(info -> namesById.put(info.getNode().getId(), info.getNode().getName()));
    return namesById;
}

private static void ensureTaskIsCancelled(String transportAction, Function<String, String> nodeIdToName) throws Exception {
SetOnce<TaskInfo> searchTask = new SetOnce<>();
ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks().get();
for (TaskInfo task : listTasksResponse.getTasks()) {
if (task.action().equals(transportAction)) {
searchTask.set(task);
}
}
assertNotNull(searchTask.get());
TaskId taskId = searchTask.get().taskId();
String nodeName = nodeIdToName.apply(taskId.getNodeId());
assertBusy(() -> {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, nodeName).getTaskManager();
Task task = taskManager.getTask(taskId.getId());
assertThat(task, instanceOf(CancellableTask.class));
assertTrue(((CancellableTask) task).isCancelled());
Copy link
Member

@javanna javanna Nov 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also check that all children tasks are cancelled? My understanding is that we are now checking that the top-level profiling action is cancelled. Ideally, we'd make sure that the children search tasks either don't start before the cancellation, or get cancelled, or complete before the cancellation.

We could check here that there are no tasks with this specific parent task id, on the other hand that would be the case even if we did not set the parent task id properly. Should we instead try and collect those children tasks before the cancellation and double check that they do get cancelled afterwards?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not entirely sure though where things get blocked with the script below, maybe always at the first inner search request? ideally it would block at any stage but that sounds like it would complicate this test quite a bit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I've implemented this now in 55c158a.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I'll merge it then when CI is green. :)

});
}

/**
 * Gathers the {@link ScriptedBlockPlugin} instances from the data nodes' plugin services, then
 * clears each plugin's hit counter and switches blocking on.
 *
 * @return the collected plugin instances
 */
private static List<ScriptedBlockPlugin> initBlockFactory() {
    List<ScriptedBlockPlugin> blockers = new ArrayList<>();
    for (PluginsService service : internalCluster().getDataNodeInstances(PluginsService.class)) {
        blockers.addAll(service.filterPlugins(ScriptedBlockPlugin.class));
    }
    // Collect first, then arm every plugin: counter back to zero and blocking enabled.
    blockers.forEach(blocker -> {
        blocker.reset();
        blocker.enableBlock();
    });
    return blockers;
}

/**
 * Waits for up to ten seconds until the blocking script has been invoked at least once,
 * counted across all given plugins.
 *
 * @param plugins the plugins whose hit counters are summed
 */
private void awaitForBlock(List<ScriptedBlockPlugin> plugins) throws Exception {
    assertBusy(() -> {
        int totalHits = plugins.stream().mapToInt(plugin -> plugin.hits.get()).sum();
        logger.info("The plugin blocked on {} shards", totalHits);
        assertThat(totalHits, greaterThan(0));
    }, 10, TimeUnit.SECONDS);
}

/**
 * Switches blocking off on every given plugin.
 *
 * @param plugins the plugins to release
 */
private static void disableBlocks(List<ScriptedBlockPlugin> plugins) {
    plugins.forEach(ScriptedBlockPlugin::disableBlock);
}

/**
 * Mock script plugin that registers a single script, {@code search_block}. Each invocation of the
 * script increments a hit counter and then waits on the blocking flag before returning {@code true}.
 */
public static class ScriptedBlockPlugin extends MockScriptPlugin {
    static final String SCRIPT_NAME = "search_block";

    /** Number of script invocations since the last {@link #reset()}. */
    private final AtomicInteger hits = new AtomicInteger();

    /** While {@code true}, script invocations wait instead of returning right away. */
    private final AtomicBoolean shouldBlock = new AtomicBoolean(true);

    /** Sets the hit counter back to zero. */
    void reset() {
        hits.set(0);
    }

    /** Clears the blocking flag. */
    void disableBlock() {
        shouldBlock.set(false);
    }

    /** Raises the blocking flag. */
    void enableBlock() {
        shouldBlock.set(true);
    }

    @Override
    public Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
        return Collections.singletonMap(SCRIPT_NAME, this::recordHitAndBlock);
    }

    /**
     * Body of the {@code search_block} script: logs the document id, counts the hit and waits until
     * the blocking flag is cleared.
     * NOTE(review): {@code waitUntil} is an inherited test helper; presumably it also gives up after
     * a timeout. Confirm against the test framework.
     *
     * @param params script parameters; the {@code _fields} entry is read as a stored-fields lookup
     * @return always {@code true}
     */
    private Object recordHitAndBlock(Map<String, Object> params) {
        LeafStoredFieldsLookup storedFields = (LeafStoredFieldsLookup) params.get("_fields");
        LogManager.getLogger(GetProfilingActionIT.class).info("Blocking on the document {}", storedFields.get("_id"));
        hits.incrementAndGet();
        try {
            waitUntil(() -> shouldBlock.get() == false);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return true;
    }
}
}