diff --git a/CHANGELOG.md b/CHANGELOG.md index 5383965916b68..8f9f1b4db9968 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix the visit of sub queries for HasParentQuery and HasChildQuery ([#18621](https://github.com/opensearch-project/OpenSearch/pull/18621)) - Fix the backward compatibility regression with COMPLEMENT for Regexp queries introduced in OpenSearch 3.0 ([#18640](https://github.com/opensearch-project/OpenSearch/pull/18640)) - Fix Replication lag computation ([#18602](https://github.com/opensearch-project/OpenSearch/pull/18602)) +- Add task cancellation checks in FetchPhase during aggregation reductions ([#18673](https://github.com/opensearch-project/OpenSearch/pull/18673)) ### Security diff --git a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java index 5073858848e05..53416124e27be 100644 --- a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java +++ b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java @@ -229,7 +229,8 @@ public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateLis SearchProgressListener.NOOP, namedWriteableRegistry, shards.size(), - exc -> {} + exc -> {}, + () -> false ); CountDownLatch latch = new CountDownLatch(shards.size()); for (int i = 0; i < shards.size(); i++) { diff --git a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java index 16a74f58c9d3a..bb6ffa599fd52 100644 --- a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java +++ 
b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/bucket/geogrid/BaseGeoGrid.java @@ -114,6 +114,7 @@ public BaseGeoGrid reduce(List aggregations, ReduceContext final int size = Math.toIntExact(reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size())); BucketPriorityQueue ordered = new BucketPriorityQueue<>(size); for (LongObjectPagedHashMap.Cursor> cursor : buckets) { + reduceContext.checkCancelled(); List sameCellBuckets = cursor.value; ordered.insertWithOverflow(reduceBucket(sameCellBuckets, reduceContext)); } diff --git a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/metrics/InternalGeoBounds.java b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/metrics/InternalGeoBounds.java index c3cb47d244c6b..573fbc47f45bf 100644 --- a/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/metrics/InternalGeoBounds.java +++ b/modules/geo/src/main/java/org/opensearch/geo/search/aggregations/metrics/InternalGeoBounds.java @@ -119,6 +119,7 @@ public InternalAggregation reduce(List aggregations, Reduce double negRight = Double.NEGATIVE_INFINITY; for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); InternalGeoBounds bounds = (InternalGeoBounds) aggregation; if (bounds.top > top) { diff --git a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java index f1b06378bd579..2bf85b3da2dc3 100644 --- a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java @@ -44,6 +44,7 @@ import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; import 
org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.builder.SearchSourceBuilder; @@ -57,6 +58,7 @@ import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BooleanSupplier; import java.util.function.Consumer; /** @@ -83,6 +85,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults onPartialMergeFailure; @@ -99,7 +102,8 @@ public QueryPhaseResultConsumer( SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, int expectedResultSize, - Consumer onPartialMergeFailure + Consumer onPartialMergeFailure, + BooleanSupplier isTaskCancelled ) { super(expectedResultSize); this.executor = executor; @@ -111,6 +115,7 @@ public QueryPhaseResultConsumer( this.topNSize = SearchPhaseController.getTopDocsSize(request); this.performFinalReduce = request.isFinalReduce(); this.onPartialMergeFailure = onPartialMergeFailure; + this.isTaskCancelled = isTaskCancelled; SearchSourceBuilder source = request.source(); this.hasTopDocs = source == null || source.size() != 0; @@ -158,7 +163,8 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { pendingMerges.numReducePhases, false, aggReduceContextBuilder, - performFinalReduce + performFinalReduce, + isTaskCancelled ); if (hasAggs) { // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result @@ -219,7 +225,9 @@ private MergeResult partialReduce( for (QuerySearchResult result : toConsume) { aggsList.add(result.consumeAggs().expand()); } - newAggs = InternalAggregations.topLevelReduce(aggsList, aggReduceContextBuilder.forPartialReduction()); + InternalAggregation.ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction(); + reduceContext.setIsTaskCancelled(isTaskCancelled); + newAggs = 
InternalAggregations.topLevelReduce(aggsList, reduceContext); } else { newAggs = null; } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index 43132b5cf58ab..f997b715529b6 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -78,6 +78,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Executor; +import java.util.function.BooleanSupplier; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.IntFunction; @@ -431,7 +432,17 @@ public ReduceContext forFinalReduction() { topDocs.add(td.topDocs); } } - return reducedQueryPhase(queryResults, Collections.emptyList(), topDocs, topDocsStats, 0, true, aggReduceContextBuilder, true); + return reducedQueryPhase( + queryResults, + Collections.emptyList(), + topDocs, + topDocsStats, + 0, + true, + aggReduceContextBuilder, + true, + () -> false + ); } /** @@ -451,7 +462,8 @@ ReducedQueryPhase reducedQueryPhase( int numReducePhases, boolean isScrollRequest, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, - boolean performFinalReduce + boolean performFinalReduce, + BooleanSupplier isTaskCancelled ) { assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases; numReducePhases++; // increment for this phase @@ -526,7 +538,7 @@ ReducedQueryPhase reducedQueryPhase( reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions)); reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class); } - final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs); + final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs, isTaskCancelled); final 
SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults); final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, bufferedTopDocs, from, size, reducedCompletionSuggestions); final TotalHits totalHits = topDocsStats.getTotalHits(); @@ -551,14 +563,17 @@ ReducedQueryPhase reducedQueryPhase( private static InternalAggregations reduceAggs( InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, boolean performFinalReduce, - List toReduce + List toReduce, + BooleanSupplier isTaskCancelled ) { - return toReduce.isEmpty() - ? null - : InternalAggregations.topLevelReduce( - toReduce, - performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction() - ); + if (toReduce.isEmpty()) { + return null; + } + final ReduceContext reduceContext = performFinalReduce + ? aggReduceContextBuilder.forFinalReduction() + : aggReduceContextBuilder.forPartialReduction(); + reduceContext.setIsTaskCancelled(isTaskCancelled); + return InternalAggregations.topLevelReduce(toReduce, reduceContext); } /** @@ -757,7 +772,8 @@ QueryPhaseResultConsumer newSearchPhaseResults( SearchProgressListener listener, SearchRequest request, int numShards, - Consumer onPartialMergeFailure + Consumer onPartialMergeFailure, + BooleanSupplier isTaskCancelled ) { return new QueryPhaseResultConsumer( request, @@ -767,7 +783,8 @@ QueryPhaseResultConsumer newSearchPhaseResults( listener, namedWriteableRegistry, numShards, - onPartialMergeFailure + onPartialMergeFailure, + isTaskCancelled ); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchResponseMerger.java b/server/src/main/java/org/opensearch/action/search/SearchResponseMerger.java index 538e7fd54e2c3..66f91069cb20c 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchResponseMerger.java +++ b/server/src/main/java/org/opensearch/action/search/SearchResponseMerger.java @@ -63,6 +63,7 @@ import 
java.util.Objects; import java.util.TreeMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BooleanSupplier; /** * Merges multiple search responses into one. Used in cross-cluster search when reduction is performed locally on each cluster. @@ -110,7 +111,7 @@ final class SearchResponseMerger { /** * Add a search response to the list of responses to be merged together into one. * Merges currently happen at once when all responses are available and - * {@link #getMergedResponse(SearchResponse.Clusters, SearchRequestContext)} )} is called. + * {@link #getMergedResponse(SearchResponse.Clusters, SearchRequestContext searchContext, BooleanSupplier isTaskCancelled)} is called. * That may change in the future as it's possible to introduce incremental merges as responses come in if necessary. */ void add(SearchResponse searchResponse) { @@ -126,7 +127,11 @@ int numResponses() { * Returns the merged response. To be called once all responses have been added through {@link #add(SearchResponse)} * so that all responses are merged into a single one. */ - SearchResponse getMergedResponse(SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext) { + SearchResponse getMergedResponse( + SearchResponse.Clusters clusters, + SearchRequestContext searchRequestContext, + BooleanSupplier isTaskCancelled + ) { // if the search is only across remote clusters, none of them are available, and all of them have skip_unavailable set to true, // we end up calling merge without anything to merge, we just return an empty search response if (searchResponses.size() == 0) { @@ -214,7 +219,9 @@ SearchResponse getMergedResponse(SearchResponse.Clusters clusters, SearchRequest SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats); setSuggestShardIndex(shards, groupedSuggestions); Suggest suggest = groupedSuggestions.isEmpty() ? 
null : new Suggest(Suggest.reduce(groupedSuggestions)); - InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forFinalReduction()); + InternalAggregation.ReduceContext reduceContext = aggReduceContextBuilder.forFinalReduction(); + reduceContext.setIsTaskCancelled(isTaskCancelled); + InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, reduceContext); ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY); SearchProfileShardResults profileShardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults); // make failures ordering consistent between ordinary search and CCS by looking at the shard they come from diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 1da080e5bd302..4b1e14d452f13 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -119,6 +119,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.BiFunction; +import java.util.function.BooleanSupplier; import java.util.function.Function; import java.util.function.LongSupplier; import java.util.stream.Collectors; @@ -548,7 +549,8 @@ private ActionListener buildRewriteListener( searchAsyncActionProvider, searchRequestContext ), - searchRequestContext + searchRequestContext, + () -> task instanceof CancellableTask ? 
((CancellableTask) task).isCancelled() : false ); } else { AtomicInteger skippedClusters = new AtomicInteger(0); @@ -639,7 +641,8 @@ static void ccsRemoteReduce( ThreadPool threadPool, ActionListener listener, BiConsumer> localSearchConsumer, - SearchRequestContext searchRequestContext + SearchRequestContext searchRequestContext, + BooleanSupplier isTaskCancelled ) { if (localIndices == null && remoteIndices.size() == 1) { @@ -728,7 +731,8 @@ public void onFailure(Exception e) { searchResponseMerger, totalClusters, listener, - searchRequestContext + searchRequestContext, + isTaskCancelled ); Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias); remoteClusterClient.search(ccsSearchRequest, ccsListener); @@ -743,7 +747,8 @@ public void onFailure(Exception e) { searchResponseMerger, totalClusters, listener, - searchRequestContext + searchRequestContext, + isTaskCancelled ); SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest( searchRequest, @@ -841,7 +846,8 @@ private static ActionListener createCCSListener( SearchResponseMerger searchResponseMerger, int totalClusters, ActionListener originalListener, - SearchRequestContext searchRequestContext + SearchRequestContext searchRequestContext, + BooleanSupplier isTaskCancelled ) { return new CCSActionListener( clusterAlias, @@ -863,7 +869,7 @@ SearchResponse createFinalResponse() { searchResponseMerger.numResponses(), skippedClusters.get() ); - return searchResponseMerger.getMergedResponse(clusters, searchRequestContext); + return searchResponseMerger.getMergedResponse(clusters, searchRequestContext, isTaskCancelled); } }; } @@ -1265,7 +1271,8 @@ private AbstractSearchAsyncAction searchAsyncAction task.getProgressListener(), searchRequest, shardIterators.size(), - exc -> cancelTask(task, exc) + exc -> cancelTask(task, exc), + task::isCancelled ); AbstractSearchAsyncAction searchAsyncAction; switch (searchRequest.searchType()) { diff --git 
a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index 0dd4c3344af1e..3b2753000fb92 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -1045,6 +1045,7 @@ public ReaderContext readerContext() { @Override public InternalAggregation.ReduceContext partialOnShard() { InternalAggregation.ReduceContext rc = requestToAggReduceContextBuilder.apply(request.source()).forPartialReduction(); + rc.setIsTaskCancelled(this::isCancelled); rc.setSliceLevel(shouldUseConcurrentSearch()); return rc; } diff --git a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java index 49b85ccaea2a8..5ed3f4fba9bec 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java +++ b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregation.java @@ -37,6 +37,7 @@ import org.opensearch.core.common.io.stream.NamedWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.rest.action.search.RestSearchAction; @@ -52,6 +53,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.BooleanSupplier; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.IntConsumer; @@ -95,6 +97,7 @@ public static class ReduceContext { private final ScriptService scriptService; private final IntConsumer multiBucketConsumer; private final PipelineTree pipelineTreeRoot; + private BooleanSupplier isTaskCancelled = () -> false; private 
boolean isSliceLevel; /** @@ -210,6 +213,23 @@ public void consumeBucketsAndMaybeBreak(int size) { multiBucketConsumer.accept(size); } + /** + * Setter for task cancellation supplier. + * @param isTaskCancelled supplier that returns {@code true} once the task running this reduction has been cancelled + */ + public void setIsTaskCancelled(BooleanSupplier isTaskCancelled) { + this.isTaskCancelled = isTaskCancelled; + } + + /** + * Checks whether the task has been cancelled, throwing a {@link TaskCancelledException} to terminate the request if so + */ + public void checkCancelled() { + if (isTaskCancelled.getAsBoolean()) { + throw new TaskCancelledException("The query has been cancelled"); + } + } + } protected final String name; @@ -288,6 +308,7 @@ public InternalAggregation reducePipelines( ) { assert reduceContext.isFinalReduce(); for (PipelineAggregator pipelineAggregator : pipelinesForThisAgg.aggregators()) { + reduceContext.checkCancelled(); reducedAggs = pipelineAggregator.reduce(reducedAggs, reduceContext); } return reducedAggs; diff --git a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java index 9d55ee4a04506..27b609b7b2104 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java +++ b/server/src/main/java/org/opensearch/search/aggregations/InternalAggregations.java @@ -168,6 +168,7 @@ public static InternalAggregations reduce(List aggregation Map> aggByName = new HashMap<>(); for (InternalAggregations aggregations : aggregationsList) { for (Aggregation aggregation : aggregations.aggregations) { + context.checkCancelled(); List aggs = aggByName.computeIfAbsent( aggregation.getName(), k -> new ArrayList<>(aggregationsList.size()) diff --git a/server/src/main/java/org/opensearch/search/aggregations/InternalMultiBucketAggregation.java b/server/src/main/java/org/opensearch/search/aggregations/InternalMultiBucketAggregation.java index 1d9df65fee92d..3e9e29c39936d 100--- 
a/server/src/main/java/org/opensearch/search/aggregations/InternalMultiBucketAggregation.java +++ b/server/src/main/java/org/opensearch/search/aggregations/InternalMultiBucketAggregation.java @@ -215,6 +215,7 @@ private List reducePipelineBuckets(ReduceContext reduceContext, PipelineTree for (B bucket : getBuckets()) { List aggs = new ArrayList<>(); for (Aggregation agg : bucket.getAggregations()) { + reduceContext.checkCancelled(); PipelineTree subTree = pipelineTree.subTree(agg.getName()); aggs.add(((InternalAggregation) agg).reducePipelines((InternalAggregation) agg, reduceContext, subTree)); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/InternalSingleBucketAggregation.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/InternalSingleBucketAggregation.java index 03fade2edb392..6a0094c46d7c9 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/InternalSingleBucketAggregation.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/InternalSingleBucketAggregation.java @@ -119,6 +119,7 @@ public InternalAggregation reduce(List aggregations, Reduce long docCount = 0L; List subAggregationsList = new ArrayList<>(aggregations.size()); for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); assert aggregation.getName().equals(getName()); docCount += ((InternalSingleBucketAggregation) aggregation).docCount; subAggregationsList.add(((InternalSingleBucketAggregation) aggregation).aggregations); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/InternalFilters.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/InternalFilters.java index 104dab01d90fe..616491280f05a 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/InternalFilters.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/InternalFilters.java @@ -213,6 +213,7 @@ public 
InternalBucket getBucketByKey(String key) { public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { List> bucketsList = null; for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); InternalFilters filters = (InternalFilters) aggregation; if (bucketsList == null) { bucketsList = new ArrayList<>(filters.buckets.size()); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalAutoDateHistogram.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalAutoDateHistogram.java index 0866d26526761..052d0e1df8798 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalAutoDateHistogram.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalAutoDateHistogram.java @@ -308,7 +308,7 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) * is the same and they can be reduced together. 
*/ private BucketReduceResult reduceBuckets(List aggregations, ReduceContext reduceContext) { - + reduceContext.checkCancelled(); // First we need to find the highest level rounding used across all the // shards int reduceRoundingIdx = 0; @@ -421,6 +421,7 @@ private List mergeBuckets(List reducedBuckets, Rounding.Prepared @Override protected Bucket reduceBucket(List buckets, ReduceContext context) { + context.checkCancelled(); assert buckets.size() > 0; List aggregations = new ArrayList<>(buckets.size()); long docCount = 0; @@ -542,6 +543,7 @@ static int getAppropriateRounding(long minKey, long maxKey, int roundingIdx, Rou @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + reduceContext.checkCancelled(); BucketReduceResult reducedBucketsResult = reduceBuckets(aggregations, reduceContext); if (reduceContext.isFinalReduce()) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalDateHistogram.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalDateHistogram.java index e0b6010c6c3e8..9b1ab2193c56f 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalDateHistogram.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalDateHistogram.java @@ -330,6 +330,7 @@ protected boolean lessThan(IteratorAndCurrent a, IteratorAndCurrent list, ReduceContext reduceContext) { @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + reduceContext.checkCancelled(); List reducedBuckets = reduceBuckets(aggregations, reduceContext); if (reduceContext.isFinalReduce()) { if (minDocCount == 0) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalHistogram.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalHistogram.java index a988b911de5a3..b99ebb4bbdb2a 100644 --- 
a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalHistogram.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/InternalHistogram.java @@ -313,6 +313,7 @@ protected boolean lessThan(IteratorAndCurrent a, IteratorAndCurrent aggregations, Reduce rangeList[i] = new ArrayList<>(); } for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); InternalRange ranges = (InternalRange) aggregation; int i = 0; for (B range : ranges.ranges) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/DoubleTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/DoubleTerms.java index de02d5a938644..fe16190c3390b 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/DoubleTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/DoubleTerms.java @@ -219,6 +219,7 @@ protected Bucket[] createBucketsArray(int size) { @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + reduceContext.checkCancelled(); boolean promoteToDouble = false; for (InternalAggregation agg : aggregations) { if (agg instanceof LongTerms diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java index 03bb519ed9961..6894e4deb50e1 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalSignificantTerms.java @@ -240,6 +240,7 @@ public InternalAggregation reduce(List aggregations, Reduce // Compute the overall result set size and the corpus size using the // top-level Aggregations from each shard for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); 
@SuppressWarnings("unchecked") InternalSignificantTerms terms = (InternalSignificantTerms) aggregation; globalSubsetSize += terms.getSubsetSize(); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java index b8f9406ff55b9..95b9791b792e6 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java @@ -395,6 +395,7 @@ public InternalAggregation reduce(List aggregations, Reduce long otherDocCount = 0; InternalTerms referenceTerms = null; for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); @SuppressWarnings("unchecked") InternalTerms terms = (InternalTerms) aggregation; // For Concurrent Segment Search the aggregation will have a computed doc count error coming from the shards. diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalHDRPercentiles.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalHDRPercentiles.java index 6f50d791594ff..bc36d93128b23 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalHDRPercentiles.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalHDRPercentiles.java @@ -146,6 +146,7 @@ public boolean keyed() { public AbstractInternalHDRPercentiles reduce(List aggregations, ReduceContext reduceContext) { DoubleHistogram merged = null; for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); final AbstractInternalHDRPercentiles percentiles = (AbstractInternalHDRPercentiles) aggregation; if (merged == null) { merged = new DoubleHistogram(percentiles.state); diff --git 
a/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalTDigestPercentiles.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalTDigestPercentiles.java index 398d0054403ac..2cd5884632e07 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalTDigestPercentiles.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/AbstractInternalTDigestPercentiles.java @@ -129,6 +129,7 @@ public boolean keyed() { public AbstractInternalTDigestPercentiles reduce(List aggregations, ReduceContext reduceContext) { TDigestState merged = null; for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); final AbstractInternalTDigestPercentiles percentiles = (AbstractInternalTDigestPercentiles) aggregation; if (merged == null) { merged = new TDigestState(percentiles.state.compression()); diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalScriptedMetric.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalScriptedMetric.java index fbcf4a4d48603..cd9dc73ca29af 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalScriptedMetric.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/InternalScriptedMetric.java @@ -99,6 +99,7 @@ List aggregationsList() { public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { List aggregationObjects = new ArrayList<>(); for (InternalAggregation aggregation : aggregations) { + reduceContext.checkCancelled(); InternalScriptedMetric mapReduceAggregation = (InternalScriptedMetric) aggregation; aggregationObjects.addAll(mapReduceAggregation.aggregations); } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index b0fab3b7a3556..f431829a0898c 100644 
--- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -814,7 +814,8 @@ private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAct task.getProgressListener(), writableRegistry(), shardsIter.size(), - exc -> {} + exc -> {}, + () -> false ); AtomicReference exception = new AtomicReference<>(); ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set); @@ -869,7 +870,8 @@ private SearchQueryThenFetchAsyncAction createSearchQueryThenFetchAsyncAction( task.getProgressListener(), writableRegistry(), shardsIter.size(), - exc -> {} + exc -> {}, + () -> false ); AtomicReference exception = new AtomicReference<>(); ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set); @@ -926,7 +928,8 @@ private FetchSearchPhase createFetchSearchPhase() { SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 1, - exc -> {} + exc -> {}, + () -> false ); return new FetchSearchPhase( results, diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index bb51aeaeee9dd..8f5d38aa490d6 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -675,7 +675,8 @@ public void sendCanMatch( task.getProgressListener(), writableRegistry(), shardsIter.size(), - exc -> {} + exc -> {}, + () -> false ); CanMatchPreFilterSearchPhase canMatchPhase = new CanMatchPreFilterSearchPhase( diff --git a/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java index 
2dd2d9f9576a9..1dbbe68716e80 100644 --- a/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java @@ -141,7 +141,8 @@ public void sendExecuteQuery( SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, results.length(), - exc -> {} + exc -> {}, + () -> false ); DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override @@ -226,7 +227,8 @@ public void sendExecuteQuery( SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, results.length(), - exc -> {} + exc -> {}, + () -> false ); DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override @@ -313,7 +315,8 @@ public void sendExecuteQuery( SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, results.length(), - exc -> {} + exc -> {}, + () -> false ); DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override diff --git a/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java index c5db475e9db02..e9314eef00e23 100644 --- a/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java @@ -72,7 +72,8 @@ public void testShortcutQueryAndFetchOptimization() { SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 1, - exc -> {} + exc -> {}, + () -> false ); boolean hasHits = randomBoolean(); final int numHits; @@ -133,7 +134,8 @@ public void testFetchTwoDocument() { SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, - exc -> {} + exc -> {}, + () -> false ); int resultSetSize = randomIntBetween(2, 10); ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 
123); @@ -229,7 +231,8 @@ public void testFailFetchOneDoc() { SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, - exc -> {} + exc -> {}, + () -> false ); int resultSetSize = randomIntBetween(2, 10); final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); @@ -327,7 +330,8 @@ public void testFetchDocsConcurrently() throws InterruptedException { SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), numHits, - exc -> {} + exc -> {}, + () -> false ); for (int i = 0; i < numHits; i++) { QuerySearchResult queryResult = new QuerySearchResult( @@ -417,7 +421,8 @@ public void testExceptionFailsPhase() { SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, - exc -> {} + exc -> {}, + () -> false ); int resultSetSize = randomIntBetween(2, 10); QuerySearchResult queryResult = new QuerySearchResult( @@ -509,7 +514,8 @@ public void testCleanupIrrelevantContexts() { // contexts that are not fetched s SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, - exc -> {} + exc -> {}, + () -> false ); int resultSetSize = 1; final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); diff --git a/server/src/test/java/org/opensearch/action/search/QueryPhaseResultConsumerTests.java b/server/src/test/java/org/opensearch/action/search/QueryPhaseResultConsumerTests.java index 283c9e2f238cc..63c0da501922b 100644 --- a/server/src/test/java/org/opensearch/action/search/QueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/opensearch/action/search/QueryPhaseResultConsumerTests.java @@ -130,7 +130,8 @@ public void testProgressListenerExceptionsAreCaught() throws Exception { e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { curr.addSuppressed(prev); return curr; - }) + }), + () -> false ); CountDownLatch partialReduceLatch = new CountDownLatch(10); diff --git a/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java 
b/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java index 964f79d23447a..d5e1dcd1d1d04 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchPhaseControllerTests.java @@ -340,7 +340,8 @@ public void testMerge() { 0, true, InternalAggregationTestCase.emptyReduceContextBuilder(), - true + true, + () -> false ); AtomicArray fetchResults = generateFetchResults( nShards, @@ -668,7 +669,8 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { SearchProgressListener.NOOP, request, 3 + numEmptyResponses, - exc -> {} + exc -> {}, + () -> false ); if (numEmptyResponses == 0) { assertEquals(0, reductions.size()); @@ -776,7 +778,8 @@ public void testConsumerConcurrently() throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; @@ -840,7 +843,8 @@ public void testConsumerOnlyAggs() throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); AtomicInteger max = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(expectedNumResults); @@ -893,7 +897,8 @@ public void testConsumerOnlyHits() throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); AtomicInteger max = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(expectedNumResults); @@ -951,7 +956,8 @@ public void testReduceTopNWithFromOffset() throws Exception { SearchProgressListener.NOOP, request, 4, - exc -> {} + exc -> {}, + () -> false ); int score = 100; CountDownLatch latch = new CountDownLatch(4); @@ -1000,7 +1006,8 @@ public void testConsumerSortByField() throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); 
AtomicInteger max = new AtomicInteger(); SortField[] sortFields = { new SortField("field", SortField.Type.INT, true) }; @@ -1047,7 +1054,8 @@ public void testConsumerFieldCollapsing() throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); SortField[] sortFields = { new SortField("field", SortField.Type.STRING) }; BytesRef a = new BytesRef("a"); @@ -1097,7 +1105,8 @@ public void testConsumerSuggestions() throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); int maxScoreTerm = -1; int maxScorePhrase = -1; @@ -1235,7 +1244,8 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna progressListener, request, expectedNumResults, - exc -> {} + exc -> {}, + () -> false ); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; @@ -1324,7 +1334,8 @@ private void testReduceCase(boolean shouldFail) throws Exception { SearchProgressListener.NOOP, request, expectedNumResults, - exc -> hasConsumedFailure.set(true) + exc -> hasConsumedFailure.set(true), + () -> false ); CountDownLatch latch = new CountDownLatch(expectedNumResults); Thread[] threads = new Thread[expectedNumResults]; diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 1ccbbf4196505..2c6922bd7cc91 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -219,7 +219,8 @@ public void sendExecuteQuery( task.getProgressListener(), writableRegistry(), shardsIter.size(), - exc -> {} + exc -> {}, + () -> false ); SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( logger, diff --git 
a/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java b/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java index e93f4553063ac..b5915f297f7e1 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java @@ -139,7 +139,8 @@ public void testMergeTookInMillis() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertEquals(TimeUnit.NANOSECONDS.toMillis(currentRelativeTime), searchResponse.getTook().millis()); } @@ -198,7 +199,8 @@ public void testMergeShardFailures() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, mergedResponse.getClusters()); assertEquals(numResponses, mergedResponse.getTotalShards()); @@ -256,7 +258,8 @@ public void testMergeShardFailuresNullShardTarget() throws InterruptedException new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, mergedResponse.getClusters()); assertEquals(numResponses, mergedResponse.getTotalShards()); @@ -309,7 +312,8 @@ public void testMergeShardFailuresNullShardId() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ).getShardFailures(); assertThat(Arrays.asList(shardFailures), containsInAnyOrder(expectedFailures.toArray(ShardSearchFailure.EMPTY_ARRAY))); } @@ -350,7 +354,8 @@ public void testMergeProfileResults() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new 
SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, mergedResponse.getClusters()); assertEquals(numResponses, mergedResponse.getTotalShards()); @@ -419,7 +424,8 @@ public void testMergeCompletionSuggestions() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, mergedResponse.getClusters()); assertEquals(numResponses, mergedResponse.getTotalShards()); @@ -498,7 +504,8 @@ public void testMergeCompletionSuggestionsTieBreak() throws InterruptedException new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, mergedResponse.getClusters()); assertEquals(numResponses, mergedResponse.getTotalShards()); @@ -579,7 +586,8 @@ public void testMergeAggs() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, mergedResponse.getClusters()); assertEquals(numResponses, mergedResponse.getTotalShards()); @@ -743,7 +751,8 @@ public void testMergeSearchHits() throws InterruptedException { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertEquals(TimeUnit.NANOSECONDS.toMillis(currentRelativeTime), searchResponse.getTook().millis()); @@ -810,7 +819,8 @@ public void testMergeNoResponsesAdded() { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertSame(clusters, response.getClusters()); assertEquals(TimeUnit.NANOSECONDS.toMillis(currentRelativeTime), response.getTook().millis()); @@ -890,7 +900,8 @@ public void testMergeEmptySearchHitsWithNonEmpty() { new 
SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertEquals(10, mergedResponse.getHits().getTotalHits().value()); assertEquals(10, mergedResponse.getHits().getHits().length); @@ -939,7 +950,8 @@ public void testMergeOnlyEmptyHits() { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), new SearchRequest(), () -> null - ) + ), + () -> false ); assertEquals(expectedTotalHits, mergedResponse.getHits().getTotalHits()); } diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java index 0a0015ae8cbf6..9cad7e57ffc48 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java @@ -489,7 +489,8 @@ public void testCCSRemoteReduceMergeFails() throws Exception { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), searchRequest, () -> null - ) + ), + () -> false ); if (localIndices == null) { assertNull(setOnce.get()); @@ -552,7 +553,8 @@ public void testCCSRemoteReduce() throws Exception { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), searchRequest, () -> null - ) + ), + () -> false ); if (localIndices == null) { assertNull(setOnce.get()); @@ -594,7 +596,8 @@ public void testCCSRemoteReduce() throws Exception { new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), searchRequest, () -> null - ) + ), + () -> false ); if (localIndices == null) { assertNull(setOnce.get()); @@ -657,7 +660,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), searchRequest, () -> null - ) + ), + 
() -> false ); if (localIndices == null) { assertNull(setOnce.get()); @@ -702,7 +706,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), searchRequest, () -> null - ) + ), + () -> false ); if (localIndices == null) { assertNull(setOnce.get()); @@ -758,7 +763,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), searchRequest, () -> null - ) + ), + () -> false ); if (localIndices == null) { assertNull(setOnce.get()); diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationCancellationTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationCancellationTests.java new file mode 100644 index 0000000000000..072f283f7fb95 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationCancellationTests.java @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.core.tasks.TaskCancelledException; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.StatsAggregationBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.BooleanSupplier; + +public class AggregationCancellationTests extends AggregatorTestCase { + + public void testCancellationDuringReduce() throws IOException { + // Create test index and data + try (Directory directory = newDirectory()) { + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + + // Index enough documents to make reduce phase substantial + int numDocs = 1000; + for (int i = 0; i < numDocs; i++) { + Document document = new Document(); + document.add(new NumericDocValuesField("value", i)); + document.add(new SortedSetDocValuesField("keyword", new BytesRef("value" + (i % 100)))); + indexWriter.addDocument(document); + } + indexWriter.close(); + + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = newSearcher(reader); + + // Create a complex aggregation with nested terms to make reduce phase longer + TermsAggregationBuilder aggregationBuilder = new 
TermsAggregationBuilder("keywords").field("keyword") + .size(100) + .subAggregation( + new TermsAggregationBuilder("nested_keywords").field("keyword") + .size(100) + .subAggregation(new StatsAggregationBuilder("stats").field("value")) + ); + + BooleanSupplier cancellationChecker = () -> true; + + // Execute aggregation + List aggs = new ArrayList<>(); + MultiBucketConsumerService.MultiBucketConsumer bucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + 100000, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + + Query query = new MatchAllDocsQuery(); + + Aggregator agg = createAggregator(query, aggregationBuilder, searcher, bucketConsumer); + + // Collect + agg.preCollection(); + searcher.search(query, agg); + agg.postCollection(); + aggs.add(agg.buildTopLevel()); + + // Create reduce context with cancellation checker + InternalAggregation.ReduceContext reduceContext = InternalAggregation.ReduceContext.forFinalReduction( + agg.context().bigArrays(), + null, + bucketConsumer, + aggregationBuilder.buildPipelineTree() + ); + + reduceContext.setIsTaskCancelled(cancellationChecker); + + // Perform reduce - should throw TaskCancelledException + expectThrows( + TaskCancelledException.class, + () -> { InternalAggregations.reduce(List.of(InternalAggregations.from(aggs)), reduceContext); } + ); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index c3408374659ef..1d2517d9bcfb6 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -759,7 +759,8 @@ public void testTransformSearchPhase() { SearchProgressListener.NOOP, writableRegistry(), 2, - exc -> {} + exc -> {}, + () -> false ); final QuerySearchResult querySearchResult = new QuerySearchResult();