Merged · Changes from 8 commits
SearchCancellationIT.java

@@ -11,6 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
@@ -28,6 +29,8 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.rest.RestStatus;
@@ -36,13 +39,16 @@
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
import org.elasticsearch.search.aggregations.timeseries.TimeSeriesAggregationBuilder;
import org.elasticsearch.search.lookup.LeafStoredFieldsLookup;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.TransportService;
import org.junit.BeforeClass;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@@ -55,9 +61,12 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static org.elasticsearch.index.IndexSettings.TIME_SERIES_END_TIME;
import static org.elasticsearch.index.IndexSettings.TIME_SERIES_START_TIME;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.elasticsearch.index.query.QueryBuilders.scriptQuery;
import static org.elasticsearch.search.SearchCancellationIT.ScriptedBlockPlugin.SEARCH_BLOCK_SCRIPT_NAME;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.containsString;
@@ -69,14 +78,20 @@
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class SearchCancellationIT extends ESIntegTestCase {

private static boolean lowLevelCancellation;

@BeforeClass
public static void init() {
lowLevelCancellation = randomBoolean();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Collections.singleton(ScriptedBlockPlugin.class);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
boolean lowLevelCancellation = randomBoolean();
logger.info("Using lowLevelCancellation: {}", lowLevelCancellation);
return Settings.builder()
.put(super.nodeSettings(nodeOrdinal, otherSettings))
@@ -227,7 +242,12 @@ public void testCancellationDuringAggregation() throws Exception {
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap())
)
.reduceScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.REDUCE_SCRIPT_NAME, Collections.emptyMap())
new Script(
ScriptType.INLINE,
"mockscript",
ScriptedBlockPlugin.REDUCE_BLOCK_SCRIPT_NAME,
Collections.emptyMap()
)
)
)
)
@@ -238,6 +258,93 @@ public void testCancellationDuringAggregation() throws Exception {
ensureSearchWasCancelled(searchResponse);
}

public void testCancellationDuringTimeSeriesAggregation() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
boolean blockInReduce = false;
int numberOfShards = between(2, 5);
long now = Instant.now().toEpochMilli();
int numberOfRefreshes = between(2, 5);
// We need to hit the low-level cancellation check, which only runs every 1024 docs, so make sure the shard we block
// on has at least 1024 docs; otherwise we might never reach the low-level cancellation there.
int numberOfDocsPerRefresh = numberOfShards * between(1500, 2000);
assertAcked(
prepareCreate("test").setSettings(
Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexSettings.MODE.getKey(), IndexMode.TIME_SERIES.name())
.put(IndexMetadata.INDEX_ROUTING_PATH.getKey(), "dim")
.put(TIME_SERIES_START_TIME.getKey(), now)
.put(TIME_SERIES_END_TIME.getKey(), now + (long) numberOfRefreshes * numberOfDocsPerRefresh + 1)
.build()
).setMapping("""
{
"properties": {
"@timestamp": {"type": "date", "format": "epoch_millis"},
"dim": {"type": "keyword", "time_series_dimension": true}
}
}
""")
);

for (int i = 0; i < numberOfRefreshes; i++) {
// Make sure we have a few segments
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int j = 0; j < numberOfDocsPerRefresh; j++) {
bulkRequestBuilder.add(
client().prepareIndex("test")
.setOpType(DocWriteRequest.OpType.CREATE)
.setSource("@timestamp", now + (long) i * numberOfDocsPerRefresh + j, "val", (double) j, "dim", String.valueOf(i))
);
}
assertNoFailures(bulkRequestBuilder.get());

Member: This won't really "make sure" we have a few segments, just make it quite likely. I'm guessing that if you ran an index stats check after this to confirm we had a few segments, it'd fail some percent of the time. It's fine, I think, but maybe the comment is optimistic.
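
If the test actually needed to verify it, a follow-up assertion on the index stats could be added after the bulk indexing. A rough, untested sketch, assuming the standard IndicesStatsResponse / SegmentsStats accessors and the Hamcrest greaterThan matcher available to ESIntegTestCase:

    IndicesStatsResponse stats = client().admin().indices().prepareStats("test").get();
    // primaries only, since the index is created with zero replicas
    assertThat(stats.getPrimaries().getSegments().getCount(), greaterThan(1L));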

}

logger.info("Executing search");
TimeSeriesAggregationBuilder timeSeriesAggregationBuilder = new TimeSeriesAggregationBuilder("test_agg");
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setQuery(matchAllQuery())
.addAggregation(
timeSeriesAggregationBuilder.subAggregation(
new ScriptedMetricAggregationBuilder("sub_agg").initScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.INIT_SCRIPT_NAME, Collections.emptyMap())
)
.mapScript(
new Script(
ScriptType.INLINE,
"mockscript",
blockInReduce ? ScriptedBlockPlugin.MAP_SCRIPT_NAME : ScriptedBlockPlugin.MAP_BLOCK_SCRIPT_NAME,
Collections.emptyMap()
)
)
.combineScript(
new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap())
)
.reduceScript(
new Script(
ScriptType.INLINE,
"mockscript",
blockInReduce ? ScriptedBlockPlugin.REDUCE_BLOCK_SCRIPT_NAME : ScriptedBlockPlugin.REDUCE_FAIL_SCRIPT_NAME,
Collections.emptyMap()
)
)
)
)
.execute();
awaitForBlock(plugins);
cancelSearch(SearchAction.NAME);
disableBlocks(plugins);

SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, searchResponse::actionGet);
assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST));
logger.info("All shards failed with", ex);
if (lowLevelCancellation) {
// Ensure that we cancelled in LeafWalker and not in reduce phase
assertThat(ExceptionsHelper.stackTrace(ex), containsString("LeafWalker"));
}

}

public void testCancellationOfScrollSearches() throws Exception {

List<ScriptedBlockPlugin> plugins = initBlockFactory();
@@ -414,8 +521,11 @@ public static class ScriptedBlockPlugin extends MockScriptPlugin {
static final String SEARCH_BLOCK_SCRIPT_NAME = "search_block";
static final String INIT_SCRIPT_NAME = "init";
static final String MAP_SCRIPT_NAME = "map";
static final String MAP_BLOCK_SCRIPT_NAME = "map_block";
static final String COMBINE_SCRIPT_NAME = "combine";
static final String REDUCE_SCRIPT_NAME = "reduce";
static final String REDUCE_FAIL_SCRIPT_NAME = "reduce_fail";
static final String REDUCE_BLOCK_SCRIPT_NAME = "reduce_block";
static final String TERM_SCRIPT_NAME = "term";

private final AtomicInteger hits = new AtomicInteger();
@@ -449,10 +559,16 @@ public Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
this::nullScript,
MAP_SCRIPT_NAME,
this::nullScript,
MAP_BLOCK_SCRIPT_NAME,
this::mapBlockScript,
COMBINE_SCRIPT_NAME,
this::nullScript,
REDUCE_SCRIPT_NAME,
REDUCE_BLOCK_SCRIPT_NAME,
this::blockScript,
REDUCE_SCRIPT_NAME,
this::termScript,
REDUCE_FAIL_SCRIPT_NAME,
this::reduceFailScript,
TERM_SCRIPT_NAME,
this::termScript
);
@@ -474,6 +590,11 @@ private Object searchBlockScript(Map<String, Object> params) {
return true;
}

private Object reduceFailScript(Map<String, Object> params) {
fail("Shouldn't reach reduce");
return true;
}

private Object nullScript(Map<String, Object> params) {
return null;
}
@@ -483,7 +604,9 @@ private Object blockScript(Map<String, Object> params) {
if (runnable != null) {
runnable.run();
}
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce");
if (shouldBlock.get()) {
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce");

Member: Let's just declare a private static final Logger logger = LogManager.getLogger()?
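
Applied here, that suggestion would look roughly like the following (sketch only, not part of this diff; it also needs an import of org.apache.logging.log4j.Logger, and LogManager.getLogger() with no arguments resolves the logger name from the calling class):

    private static final Logger logger = LogManager.getLogger();
    ...
    if (shouldBlock.get()) {
        logger.info("Blocking in reduce");
    }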

}
hits.incrementAndGet();
try {
assertBusy(() -> assertFalse(shouldBlock.get()));
@@ -493,6 +616,23 @@ private Object blockScript(Map<String, Object> params) {
return 42;
}

private Object mapBlockScript(Map<String, Object> params) {
final Runnable runnable = beforeExecution.get();
if (runnable != null) {
runnable.run();
}
if (shouldBlock.get()) {
LogManager.getLogger(SearchCancellationIT.class).info("Blocking in map");
}
hits.incrementAndGet();
try {
assertBusy(() -> assertFalse(shouldBlock.get()));
} catch (Exception e) {
throw new RuntimeException(e);
}
return 1;
}

private Object termScript(Map<String, Object> params) {
return 1;
}
AggregationPhase.java

@@ -8,11 +8,14 @@
package org.elasticsearch.search.aggregations;

import org.apache.lucene.search.Collector;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.timeseries.TimeSeriesIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.query.CollectorResult;
import org.elasticsearch.search.profile.query.InternalProfileCollector;
import org.elasticsearch.search.query.QueryPhase;

import java.io.IOException;
import java.util.ArrayList;
@@ -40,7 +43,7 @@ public void preProcess(SearchContext context) {
}
if (context.aggregations().factories().context() != null
&& context.aggregations().factories().context().isInSortOrderExecutionRequired()) {
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher());
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context));
try {
searcher.search(context.rewrittenQuery(), bucketCollector);
} catch (IOException e) {
@@ -55,6 +58,36 @@ public void preProcess(SearchContext context) {
}
}

private List<Runnable> getCancellationChecks(SearchContext context) {
List<Runnable> cancellationChecks = new ArrayList<>();
if (context.lowLevelCancellation()) {
// This searching doesn't live beyond this phase, so we don't need to remove query cancellation
cancellationChecks.add(() -> {
final SearchShardTask task = context.getTask();
if (task != null) {
task.ensureNotCancelled();
}
});
}

boolean timeoutSet = context.scrollContext() == null
&& context.timeout() != null
&& context.timeout().equals(SearchService.NO_TIMEOUT) == false;

if (timeoutSet) {
final long startTime = context.getRelativeTimeInMillis();
final long timeout = context.timeout().millis();
final long maxTime = startTime + timeout;
cancellationChecks.add(() -> {
final long time = context.getRelativeTimeInMillis();
if (time > maxTime) {
throw new QueryPhase.TimeExceededException();
}
});
}
return cancellationChecks;
}

public void execute(SearchContext context) {
if (context.aggregations() == null) {
context.queryResult().aggregations(null);
TimeSeriesIndexSearcher.java

@@ -12,7 +12,9 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
@@ -26,6 +28,7 @@
import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper;
import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.internal.CancellableScorer;

import java.io.IOException;
import java.util.ArrayList;
@@ -41,14 +44,16 @@ public class TimeSeriesIndexSearcher {
// We need to delegate to the other searcher here as opposed to extending IndexSearcher and inheriting default implementations as the
// IndexSearcher would most of the time be a ContextIndexSearcher that has important logic related to e.g. document-level security.
private final IndexSearcher searcher;
private final List<Runnable> cancellations;

public TimeSeriesIndexSearcher(IndexSearcher searcher) {
public TimeSeriesIndexSearcher(IndexSearcher searcher, List<Runnable> cancellations) {
this.searcher = searcher;
this.cancellations = cancellations;
}

public void search(Query query, BucketCollector bucketCollector) throws IOException {
query = searcher.rewrite(query);
Weight weight = searcher.createWeight(query, bucketCollector.scoreMode(), 1);
Weight weight = wrapWeight(searcher.createWeight(query, bucketCollector.scoreMode(), 1));

// Create LeafWalker for each subreader
List<LeafWalker> leafWalkers = new ArrayList<>();
@@ -131,6 +136,45 @@ private boolean queueAllHaveTsid(PriorityQueue<LeafWalker> queue, BytesRef tsid)
return true;
}

private void checkCancelled() {
for (Runnable r : cancellations) {
r.run();
}
}

private Weight wrapWeight(Weight weight) {
if (cancellations.isEmpty() == false) {
return new Weight(weight.getQuery()) {
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
return weight.explain(context, doc);
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return weight.isCacheable(ctx);
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
Scorer scorer = weight.scorer(context);
if (scorer != null) {
return new CancellableScorer(scorer, () -> checkCancelled());
} else {
return null;
}
}

@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
return weight.bulkScorer(context);
}
};
} else {
return weight;
}
}
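
CancellableScorer itself is not shown in this hunk. Presumably it delegates to the wrapped Scorer and runs the supplied checks while documents are iterated; a hypothetical sketch under that assumption (not the actual class added by this PR) might look like:

    // Hypothetical sketch - the real CancellableScorer in this PR may differ, e.g. by checking only at intervals.
    class CancellableScorer extends Scorer {
        private final Scorer in;
        private final Runnable checkCancelled;

        CancellableScorer(Scorer in, Runnable checkCancelled) {
            super(in.getWeight());
            this.in = in;
            this.checkCancelled = checkCancelled;
        }

        @Override
        public DocIdSetIterator iterator() {
            DocIdSetIterator delegate = in.iterator();
            return new DocIdSetIterator() {
                @Override
                public int docID() {
                    return delegate.docID();
                }

                @Override
                public int nextDoc() throws IOException {
                    // run the cancellation/timeout checks registered by the aggregation phase
                    checkCancelled.run();
                    return delegate.nextDoc();
                }

                @Override
                public int advance(int target) throws IOException {
                    checkCancelled.run();
                    return delegate.advance(target);
                }

                @Override
                public long cost() {
                    return delegate.cost();
                }
            };
        }

        @Override
        public int docID() {
            return in.docID();
        }

        @Override
        public float score() throws IOException {
            return in.score();
        }

        @Override
        public float getMaxScore(int upTo) throws IOException {
            return in.getMaxScore(upTo);
        }
    }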

private static class LeafWalker {
private final LeafCollector collector;
private final Bits liveDocs;