diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java index 9a800c2656c45..465c394403bef 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -28,6 +29,8 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.IndexMode; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.rest.RestStatus; @@ -36,13 +39,16 @@ import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; +import org.elasticsearch.search.aggregations.timeseries.TimeSeriesAggregationBuilder; import org.elasticsearch.search.lookup.LeafStoredFieldsLookup; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.transport.TransportService; +import org.junit.BeforeClass; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -55,9 +61,12 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import static org.elasticsearch.index.IndexSettings.TIME_SERIES_END_TIME; +import static org.elasticsearch.index.IndexSettings.TIME_SERIES_START_TIME; import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; import static org.elasticsearch.index.query.QueryBuilders.scriptQuery; import static org.elasticsearch.search.SearchCancellationIT.ScriptedBlockPlugin.SEARCH_BLOCK_SCRIPT_NAME; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; import static org.hamcrest.Matchers.containsString; @@ -69,6 +78,13 @@ @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE) public class SearchCancellationIT extends ESIntegTestCase { + private static boolean lowLevelCancellation; + + @BeforeClass + public static void init() { + lowLevelCancellation = randomBoolean(); + } + @Override protected Collection> nodePlugins() { return Collections.singleton(ScriptedBlockPlugin.class); @@ -76,7 +92,6 @@ protected Collection> nodePlugins() { @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { - boolean lowLevelCancellation = randomBoolean(); logger.info("Using lowLevelCancellation: {}", lowLevelCancellation); return Settings.builder() .put(super.nodeSettings(nodeOrdinal, otherSettings)) @@ -227,7 +242,12 @@ public void testCancellationDuringAggregation() throws Exception { new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap()) ) .reduceScript( - new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.REDUCE_SCRIPT_NAME, Collections.emptyMap()) + new Script( + ScriptType.INLINE, + "mockscript", + ScriptedBlockPlugin.REDUCE_BLOCK_SCRIPT_NAME, + Collections.emptyMap() + ) ) ) ) @@ -238,6 +258,80 @@ public void testCancellationDuringAggregation() throws Exception { ensureSearchWasCancelled(searchResponse); } + public void testCancellationDuringTimeSeriesAggregation() throws Exception { + List plugins = initBlockFactory(); + int numberOfShards = between(2, 5); + long now = Instant.now().toEpochMilli(); + int numberOfRefreshes = between(1, 5); + int numberOfDocsPerRefresh = numberOfShards * between(1500, 2000) / numberOfRefreshes; + assertAcked( + prepareCreate("test").setSettings( + Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexSettings.MODE.getKey(), IndexMode.TIME_SERIES.name()) + .put(IndexMetadata.INDEX_ROUTING_PATH.getKey(), "dim") + .put(TIME_SERIES_START_TIME.getKey(), now) + .put(TIME_SERIES_END_TIME.getKey(), now + (long) numberOfRefreshes * numberOfDocsPerRefresh + 1) + .build() + ).setMapping(""" + { + "properties": { + "@timestamp": {"type": "date", "format": "epoch_millis"}, + "dim": {"type": "keyword", "time_series_dimension": true} + } + } + """) + ); + + for (int i = 0; i < numberOfRefreshes; i++) { + // Make sure we sometimes have a few segments + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int j = 0; j < numberOfDocsPerRefresh; j++) { + bulkRequestBuilder.add( + client().prepareIndex("test") + .setOpType(DocWriteRequest.OpType.CREATE) + .setSource("@timestamp", now + (long) i * numberOfDocsPerRefresh + j, "val", (double) j, "dim", String.valueOf(i)) + ); + } + assertNoFailures(bulkRequestBuilder.get()); + } + + logger.info("Executing search"); + TimeSeriesAggregationBuilder timeSeriesAggregationBuilder = new TimeSeriesAggregationBuilder("test_agg"); + ActionFuture searchResponse = client().prepareSearch("test") + .setQuery(matchAllQuery()) + .addAggregation( + timeSeriesAggregationBuilder.subAggregation( + new ScriptedMetricAggregationBuilder("sub_agg").initScript( + new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.INIT_SCRIPT_NAME, Collections.emptyMap()) + ) + .mapScript( + new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.MAP_BLOCK_SCRIPT_NAME, Collections.emptyMap()) + ) + .combineScript( + new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.COMBINE_SCRIPT_NAME, Collections.emptyMap()) + ) + .reduceScript( + new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.REDUCE_FAIL_SCRIPT_NAME, Collections.emptyMap()) + ) + ) + ) + .execute(); + awaitForBlock(plugins); + cancelSearch(SearchAction.NAME); + disableBlocks(plugins); + + SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, searchResponse::actionGet); + assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST)); + logger.info("All shards failed with", ex); + if (lowLevelCancellation) { + // Ensure that we cancelled in TimeSeriesIndexSearcher and not in reduce phase + assertThat(ExceptionsHelper.stackTrace(ex), containsString("TimeSeriesIndexSearcher")); + } + + } + public void testCancellationOfScrollSearches() throws Exception { List plugins = initBlockFactory(); @@ -414,8 +508,11 @@ public static class ScriptedBlockPlugin extends MockScriptPlugin { static final String SEARCH_BLOCK_SCRIPT_NAME = "search_block"; static final String INIT_SCRIPT_NAME = "init"; static final String MAP_SCRIPT_NAME = "map"; + static final String MAP_BLOCK_SCRIPT_NAME = "map_block"; static final String COMBINE_SCRIPT_NAME = "combine"; static final String REDUCE_SCRIPT_NAME = "reduce"; + static final String REDUCE_FAIL_SCRIPT_NAME = "reduce_fail"; + static final String REDUCE_BLOCK_SCRIPT_NAME = "reduce_block"; static final String TERM_SCRIPT_NAME = "term"; private final AtomicInteger hits = new AtomicInteger(); @@ -449,10 +546,16 @@ public Map, Object>> pluginScripts() { this::nullScript, MAP_SCRIPT_NAME, this::nullScript, + MAP_BLOCK_SCRIPT_NAME, + this::mapBlockScript, COMBINE_SCRIPT_NAME, this::nullScript, - REDUCE_SCRIPT_NAME, + REDUCE_BLOCK_SCRIPT_NAME, this::blockScript, + REDUCE_SCRIPT_NAME, + this::termScript, + REDUCE_FAIL_SCRIPT_NAME, + this::reduceFailScript, TERM_SCRIPT_NAME, this::termScript ); @@ -474,6 +577,11 @@ private Object searchBlockScript(Map params) { return true; } + private Object reduceFailScript(Map params) { + fail("Shouldn't reach reduce"); + return true; + } + private Object nullScript(Map params) { return null; } @@ -483,7 +591,9 @@ private Object blockScript(Map params) { if (runnable != null) { runnable.run(); } - LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce"); + if (shouldBlock.get()) { + LogManager.getLogger(SearchCancellationIT.class).info("Blocking in reduce"); + } hits.incrementAndGet(); try { assertBusy(() -> assertFalse(shouldBlock.get())); @@ -493,6 +603,23 @@ private Object blockScript(Map params) { return 42; } + private Object mapBlockScript(Map params) { + final Runnable runnable = beforeExecution.get(); + if (runnable != null) { + runnable.run(); + } + if (shouldBlock.get()) { + LogManager.getLogger(SearchCancellationIT.class).info("Blocking in map"); + } + hits.incrementAndGet(); + try { + assertBusy(() -> assertFalse(shouldBlock.get())); + } catch (Exception e) { + throw new RuntimeException(e); + } + return 1; + } + private Object termScript(Map params) { return 1; } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java index ffcc971eeda7a..ce28ab0499d54 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java @@ -8,11 +8,14 @@ package org.elasticsearch.search.aggregations; import org.apache.lucene.search.Collector; +import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.timeseries.TimeSeriesIndexSearcher; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.profile.query.CollectorResult; import org.elasticsearch.search.profile.query.InternalProfileCollector; +import org.elasticsearch.search.query.QueryPhase; import java.io.IOException; import java.util.ArrayList; @@ -40,7 +43,7 @@ public void preProcess(SearchContext context) { } if (context.aggregations().factories().context() != null && context.aggregations().factories().context().isInSortOrderExecutionRequired()) { - TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher()); + TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context)); try { searcher.search(context.rewrittenQuery(), bucketCollector); } catch (IOException e) { @@ -55,6 +58,36 @@ public void preProcess(SearchContext context) { } } + private List getCancellationChecks(SearchContext context) { + List cancellationChecks = new ArrayList<>(); + if (context.lowLevelCancellation()) { + // This searching doesn't live beyond this phase, so we don't need to remove query cancellation + cancellationChecks.add(() -> { + final SearchShardTask task = context.getTask(); + if (task != null) { + task.ensureNotCancelled(); + } + }); + } + + boolean timeoutSet = context.scrollContext() == null + && context.timeout() != null + && context.timeout().equals(SearchService.NO_TIMEOUT) == false; + + if (timeoutSet) { + final long startTime = context.getRelativeTimeInMillis(); + final long timeout = context.timeout().millis(); + final long maxTime = startTime + timeout; + cancellationChecks.add(() -> { + final long time = context.getRelativeTimeInMillis(); + if (time > maxTime) { + throw new QueryPhase.TimeExceededException(); + } + }); + } + return cancellationChecks; + } + public void execute(SearchContext context) { if (context.aggregations() == null) { context.queryResult().aggregations(null); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcher.java index 4837a291df98f..71ccf96fd6bc2 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcher.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcher.java @@ -37,22 +37,29 @@ * TODO: Convert it to use index sort instead of hard-coded tsid and timestamp values */ public class TimeSeriesIndexSearcher { + private static final int CHECK_CANCELLED_SCORER_INTERVAL = 1 << 11; // We need to delegate to the other searcher here as opposed to extending IndexSearcher and inheriting default implementations as the // IndexSearcher would most of the time be a ContextIndexSearcher that has important logic related to e.g. document-level security. private final IndexSearcher searcher; + private final List cancellations; - public TimeSeriesIndexSearcher(IndexSearcher searcher) { + public TimeSeriesIndexSearcher(IndexSearcher searcher, List cancellations) { this.searcher = searcher; + this.cancellations = cancellations; } public void search(Query query, BucketCollector bucketCollector) throws IOException { + int seen = 0; query = searcher.rewrite(query); Weight weight = searcher.createWeight(query, bucketCollector.scoreMode(), 1); // Create LeafWalker for each subreader List leafWalkers = new ArrayList<>(); for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) { + if (++seen % CHECK_CANCELLED_SCORER_INTERVAL == 0) { + checkCancelled(); + } LeafBucketCollector leafCollector = bucketCollector.getLeafCollector(leaf); Scorer scorer = weight.scorer(leaf); if (scorer != null) { @@ -76,6 +83,9 @@ protected boolean lessThan(LeafWalker a, LeafWalker b) { // walkers are ordered by timestamp. while (populateQueue(leafWalkers, queue)) { do { + if (++seen % CHECK_CANCELLED_SCORER_INTERVAL == 0) { + checkCancelled(); + } LeafWalker walker = queue.top(); walker.collectCurrent(); if (walker.nextDoc() == DocIdSetIterator.NO_MORE_DOCS || walker.shouldPop()) { @@ -131,6 +141,12 @@ private boolean queueAllHaveTsid(PriorityQueue queue, BytesRef tsid) return true; } + private void checkCancelled() { + for (Runnable r : cancellations) { + r.run(); + } + } + private static class LeafWalker { private final LeafCollector collector; private final Bits liveDocs; diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 0b72df78a510f..937378719ff81 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -267,5 +267,5 @@ private static boolean canEarlyTerminate(IndexReader reader, SortAndFormats sort return true; } - static class TimeExceededException extends RuntimeException {} + public static class TimeExceededException extends RuntimeException {} } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesCancellationTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesCancellationTests.java new file mode 100644 index 0000000000000..b66db7736a7ff --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesCancellationTests.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ +package org.elasticsearch.search.aggregations.timeseries; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.cluster.metadata.DataStream; +import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper; +import org.elasticsearch.search.aggregations.BucketCollector; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.internal.ContextIndexSearcher; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.Matchers.equalTo; + +public class TimeSeriesCancellationTests extends ESTestCase { + + private static Directory dir; + private static IndexReader reader; + + @BeforeClass + public static void setup() throws IOException { + dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setIndexSort( + new Sort( + new SortField(TimeSeriesIdFieldMapper.NAME, SortField.Type.STRING), + new SortField(DataStream.TimestampField.FIXED_TIMESTAMP_FIELD, SortField.Type.LONG) + ) + ); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwc); + indexRandomDocuments(iw, randomIntBetween(2048, 4096)); + iw.flush(); + reader = iw.getReader(); + iw.close(); + } + + private static void indexRandomDocuments(RandomIndexWriter w, int numDocs) throws IOException { + for (int i = 1; i <= numDocs; ++i) { + Document doc = new Document(); + String tsid = "tsid" + randomIntBetween(0, 30); + long time = randomNonNegativeLong(); + doc.add(new SortedDocValuesField(TimeSeriesIdFieldMapper.NAME, new BytesRef(tsid))); + doc.add(new NumericDocValuesField(DataStream.TimestampField.FIXED_TIMESTAMP_FIELD, time)); + w.addDocument(doc); + } + } + + @AfterClass + public static void cleanup() throws IOException { + IOUtils.close(reader, dir); + dir = null; + reader = null; + } + + public void testLowLevelCancellationActions() throws IOException { + ContextIndexSearcher searcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true + ); + TimeSeriesIndexSearcher timeSeriesIndexSearcher = new TimeSeriesIndexSearcher( + searcher, + List.of(() -> { throw new TaskCancelledException("Cancel"); }) + ); + CountingBucketCollector bc = new CountingBucketCollector(); + expectThrows(TaskCancelledException.class, () -> timeSeriesIndexSearcher.search(new MatchAllDocsQuery(), bc)); + // We count every segment and every record as 1 and break on 2048th iteration counting from 0 + // so we expect to see 2048 - number_of_segments - 1 (-1 is because we check before we collect) + assertThat(bc.count.get(), equalTo(Math.max(0, 2048 - reader.leaves().size() - 1))); + } + + public static class CountingBucketCollector extends BucketCollector { + public AtomicInteger count = new AtomicInteger(); + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException { + return new LeafBucketCollector() { + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + count.incrementAndGet(); + } + }; + } + + @Override + public void preCollection() throws IOException { + + } + + @Override + public void postCollection() throws IOException { + + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcherTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcherTests.java index 670a6b1f1d31d..7bc5a2522d55b 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcherTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/timeseries/TimeSeriesIndexSearcherTests.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.io.UncheckedIOException; +import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -85,7 +86,7 @@ public void testCollectInOrderAcrossSegments() throws IOException, InterruptedEx IndexReader reader = DirectoryReader.open(dir); IndexSearcher searcher = new IndexSearcher(reader); - TimeSeriesIndexSearcher indexSearcher = new TimeSeriesIndexSearcher(searcher); + TimeSeriesIndexSearcher indexSearcher = new TimeSeriesIndexSearcher(searcher, List.of()); BucketCollector collector = new BucketCollector() { diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 369d07ab26446..dfdfd267373b5 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -578,7 +578,7 @@ private A searchAndReduce( C a = createAggregator(builder, context); a.preCollection(); if (context.isInSortOrderExecutionRequired()) { - new TimeSeriesIndexSearcher(subSearcher).search(rewritten, a); + new TimeSeriesIndexSearcher(subSearcher, List.of()).search(rewritten, a); } else { Weight weight = subSearcher.createWeight(rewritten, ScoreMode.COMPLETE, 1f); subSearcher.search(weight, a); @@ -589,7 +589,7 @@ private A searchAndReduce( } else { root.preCollection(); if (context.isInSortOrderExecutionRequired()) { - new TimeSeriesIndexSearcher(searcher).search(rewritten, MultiBucketCollector.wrap(true, List.of(root))); + new TimeSeriesIndexSearcher(searcher, List.of()).search(rewritten, MultiBucketCollector.wrap(true, List.of(root))); } else { searcher.search(rewritten, MultiBucketCollector.wrap(true, List.of(root))); }