Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix the visit of sub queries for HasParentQuery and HasChildQuery ([#18621](https://github.com/opensearch-project/OpenSearch/pull/18621))
- Fix the backward compatibility regression with COMPLEMENT for Regexp queries introduced in OpenSearch 3.0 ([#18640](https://github.com/opensearch-project/OpenSearch/pull/18640))
- Fix Replication lag computation ([#18602](https://github.com/opensearch-project/OpenSearch/pull/18602))
- Add task cancellation checks in FetchPhase during aggregation reductions ([18673](https://github.com/opensearch-project/OpenSearch/pull/18673))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateLis
SearchProgressListener.NOOP,
namedWriteableRegistry,
shards.size(),
exc -> {}
exc -> {},
() -> false
);
CountDownLatch latch = new CountDownLatch(shards.size());
for (int i = 0; i < shards.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public BaseGeoGrid reduce(List<InternalAggregation> aggregations, ReduceContext
final int size = Math.toIntExact(reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size()));
BucketPriorityQueue<BaseGeoGridBucket> ordered = new BucketPriorityQueue<>(size);
for (LongObjectPagedHashMap.Cursor<List<BaseGeoGridBucket>> cursor : buckets) {
reduceContext.checkCancelled();
List<BaseGeoGridBucket> sameCellBuckets = cursor.value;
ordered.insertWithOverflow(reduceBucket(sameCellBuckets, reduceContext));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
double negRight = Double.NEGATIVE_INFINITY;

for (InternalAggregation aggregation : aggregations) {
reduceContext.checkCancelled();
InternalGeoBounds bounds = (InternalGeoBounds) aggregation;

if (bounds.top > top) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregation.ReduceContextBuilder;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand All @@ -57,6 +58,7 @@
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;

/**
Expand All @@ -83,6 +85,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
private final boolean hasTopDocs;
private final boolean hasAggs;
private final boolean performFinalReduce;
private final BooleanSupplier isTaskCancelled;

private final PendingMerges pendingMerges;
private final Consumer<Exception> onPartialMergeFailure;
Expand All @@ -99,7 +102,8 @@ public QueryPhaseResultConsumer(
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
super(expectedResultSize);
this.executor = executor;
Expand All @@ -111,6 +115,7 @@ public QueryPhaseResultConsumer(
this.topNSize = SearchPhaseController.getTopDocsSize(request);
this.performFinalReduce = request.isFinalReduce();
this.onPartialMergeFailure = onPartialMergeFailure;
this.isTaskCancelled = isTaskCancelled;

SearchSourceBuilder source = request.source();
this.hasTopDocs = source == null || source.size() != 0;
Expand Down Expand Up @@ -158,7 +163,8 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
pendingMerges.numReducePhases,
false,
aggReduceContextBuilder,
performFinalReduce
performFinalReduce,
isTaskCancelled
);
if (hasAggs) {
// Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result
Expand Down Expand Up @@ -219,7 +225,9 @@ private MergeResult partialReduce(
for (QuerySearchResult result : toConsume) {
aggsList.add(result.consumeAggs().expand());
}
newAggs = InternalAggregations.topLevelReduce(aggsList, aggReduceContextBuilder.forPartialReduction());
InternalAggregation.ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction();
reduceContext.setIsTaskCancelled(isTaskCancelled);
newAggs = InternalAggregations.topLevelReduce(aggsList, reduceContext);
} else {
newAggs = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
Expand Down Expand Up @@ -431,7 +432,17 @@
topDocs.add(td.topDocs);
}
}
return reducedQueryPhase(queryResults, Collections.emptyList(), topDocs, topDocsStats, 0, true, aggReduceContextBuilder, true);
return reducedQueryPhase(
queryResults,
Collections.emptyList(),
topDocs,
topDocsStats,
0,
true,
aggReduceContextBuilder,
true,
() -> false

Check warning on line 444 in server/src/main/java/org/opensearch/action/search/SearchPhaseController.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/search/SearchPhaseController.java#L444

Added line #L444 was not covered by tests
);
}

/**
Expand All @@ -451,7 +462,8 @@
int numReducePhases,
boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce
boolean performFinalReduce,
BooleanSupplier isTaskCancelled
) {
assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
numReducePhases++; // increment for this phase
Expand Down Expand Up @@ -526,7 +538,7 @@
reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
}
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs);
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs, isTaskCancelled);
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, bufferedTopDocs, from, size, reducedCompletionSuggestions);
final TotalHits totalHits = topDocsStats.getTotalHits();
Expand All @@ -551,14 +563,17 @@
private static InternalAggregations reduceAggs(
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<InternalAggregations> toReduce
List<InternalAggregations> toReduce,
BooleanSupplier isTaskCancelled
) {
return toReduce.isEmpty()
? null
: InternalAggregations.topLevelReduce(
toReduce,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
);
if (toReduce.isEmpty()) {
return null;
}
final ReduceContext reduceContext = performFinalReduce
? aggReduceContextBuilder.forFinalReduction()
: aggReduceContextBuilder.forPartialReduction();
reduceContext.setIsTaskCancelled(isTaskCancelled);
return InternalAggregations.topLevelReduce(toReduce, reduceContext);
}

/**
Expand Down Expand Up @@ -757,7 +772,8 @@
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
return new QueryPhaseResultConsumer(
request,
Expand All @@ -767,7 +783,8 @@
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure
onPartialMergeFailure,
isTaskCancelled
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BooleanSupplier;

/**
* Merges multiple search responses into one. Used in cross-cluster search when reduction is performed locally on each cluster.
Expand Down Expand Up @@ -110,7 +111,7 @@ final class SearchResponseMerger {
/**
* Add a search response to the list of responses to be merged together into one.
* Merges currently happen at once when all responses are available and
* {@link #getMergedResponse(SearchResponse.Clusters, SearchRequestContext)} )} is called.
* {@link #getMergedResponse(SearchResponse.Clusters, SearchRequestContext searchCotext, BooleanSupplier isTaskCancelled)} )} is called.
* That may change in the future as it's possible to introduce incremental merges as responses come in if necessary.
*/
void add(SearchResponse searchResponse) {
Expand All @@ -126,7 +127,11 @@ int numResponses() {
* Returns the merged response. To be called once all responses have been added through {@link #add(SearchResponse)}
* so that all responses are merged into a single one.
*/
SearchResponse getMergedResponse(SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext) {
SearchResponse getMergedResponse(
SearchResponse.Clusters clusters,
SearchRequestContext searchRequestContext,
BooleanSupplier isTaskCancelled
) {
// if the search is only across remote clusters, none of them are available, and all of them have skip_unavailable set to true,
// we end up calling merge without anything to merge, we just return an empty search response
if (searchResponses.size() == 0) {
Expand Down Expand Up @@ -214,7 +219,9 @@ SearchResponse getMergedResponse(SearchResponse.Clusters clusters, SearchRequest
SearchHits mergedSearchHits = topDocsToSearchHits(topDocs, topDocsStats);
setSuggestShardIndex(shards, groupedSuggestions);
Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, aggReduceContextBuilder.forFinalReduction());
InternalAggregation.ReduceContext reduceContext = aggReduceContextBuilder.forFinalReduction();
reduceContext.setIsTaskCancelled(isTaskCancelled);
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(aggs, reduceContext);
ShardSearchFailure[] shardFailures = failures.toArray(ShardSearchFailure.EMPTY_ARRAY);
SearchProfileShardResults profileShardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
// make failures ordering consistent between ordinary search and CCS by looking at the shard they come from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BooleanSupplier;
import java.util.function.Function;
import java.util.function.LongSupplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -548,7 +549,8 @@
searchAsyncActionProvider,
searchRequestContext
),
searchRequestContext
searchRequestContext,
() -> task instanceof CancellableTask ? ((CancellableTask) task).isCancelled() : false
);
} else {
AtomicInteger skippedClusters = new AtomicInteger(0);
Expand Down Expand Up @@ -639,7 +641,8 @@
ThreadPool threadPool,
ActionListener<SearchResponse> listener,
BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer,
SearchRequestContext searchRequestContext
SearchRequestContext searchRequestContext,
BooleanSupplier isTaskCancelled
) {

if (localIndices == null && remoteIndices.size() == 1) {
Expand Down Expand Up @@ -728,7 +731,8 @@
searchResponseMerger,
totalClusters,
listener,
searchRequestContext
searchRequestContext,
isTaskCancelled
);
Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
remoteClusterClient.search(ccsSearchRequest, ccsListener);
Expand All @@ -743,7 +747,8 @@
searchResponseMerger,
totalClusters,
listener,
searchRequestContext
searchRequestContext,
isTaskCancelled
);
SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(
searchRequest,
Expand Down Expand Up @@ -841,7 +846,8 @@
SearchResponseMerger searchResponseMerger,
int totalClusters,
ActionListener<SearchResponse> originalListener,
SearchRequestContext searchRequestContext
SearchRequestContext searchRequestContext,
BooleanSupplier isTaskCancelled
) {
return new CCSActionListener<SearchResponse, SearchResponse>(
clusterAlias,
Expand All @@ -863,7 +869,7 @@
searchResponseMerger.numResponses(),
skippedClusters.get()
);
return searchResponseMerger.getMergedResponse(clusters, searchRequestContext);
return searchResponseMerger.getMergedResponse(clusters, searchRequestContext, isTaskCancelled);
}
};
}
Expand Down Expand Up @@ -1265,7 +1271,8 @@
task.getProgressListener(),
searchRequest,
shardIterators.size(),
exc -> cancelTask(task, exc)
exc -> cancelTask(task, exc),

Check warning on line 1274 in server/src/main/java/org/opensearch/action/search/TransportSearchAction.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/action/search/TransportSearchAction.java#L1274

Added line #L1274 was not covered by tests
task::isCancelled
);
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
switch (searchRequest.searchType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ public ReaderContext readerContext() {
@Override
public InternalAggregation.ReduceContext partialOnShard() {
InternalAggregation.ReduceContext rc = requestToAggReduceContextBuilder.apply(request.source()).forPartialReduction();
rc.setIsTaskCancelled(this::isCancelled);
rc.setSliceLevel(shouldUseConcurrentSearch());
return rc;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.rest.action.search.RestSearchAction;
Expand All @@ -52,6 +53,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntConsumer;
Expand Down Expand Up @@ -95,6 +97,7 @@ public static class ReduceContext {
private final ScriptService scriptService;
private final IntConsumer multiBucketConsumer;
private final PipelineTree pipelineTreeRoot;
private BooleanSupplier isTaskCancelled = () -> false;

private boolean isSliceLevel;
/**
Expand Down Expand Up @@ -210,6 +213,23 @@ public void consumeBucketsAndMaybeBreak(int size) {
multiBucketConsumer.accept(size);
}

/**
* Setter for task cancellation supplier.
* @param isTaskCancelled
*/
public void setIsTaskCancelled(BooleanSupplier isTaskCancelled) {
this.isTaskCancelled = isTaskCancelled;
}

/**
* Will check and throw the exception to terminate the request
*/
public void checkCancelled() {
if (isTaskCancelled.getAsBoolean()) {
throw new TaskCancelledException("The query has been cancelled");
}
}

}

protected final String name;
Expand Down Expand Up @@ -288,6 +308,7 @@ public InternalAggregation reducePipelines(
) {
assert reduceContext.isFinalReduce();
for (PipelineAggregator pipelineAggregator : pipelinesForThisAgg.aggregators()) {
reduceContext.checkCancelled();
reducedAggs = pipelineAggregator.reduce(reducedAggs, reduceContext);
}
return reducedAggs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ public static InternalAggregations reduce(List<InternalAggregations> aggregation
Map<String, List<InternalAggregation>> aggByName = new HashMap<>();
for (InternalAggregations aggregations : aggregationsList) {
for (Aggregation aggregation : aggregations.aggregations) {
context.checkCancelled();
List<InternalAggregation> aggs = aggByName.computeIfAbsent(
aggregation.getName(),
k -> new ArrayList<>(aggregationsList.size())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ private List<B> reducePipelineBuckets(ReduceContext reduceContext, PipelineTree
for (B bucket : getBuckets()) {
List<InternalAggregation> aggs = new ArrayList<>();
for (Aggregation agg : bucket.getAggregations()) {
reduceContext.checkCancelled();
PipelineTree subTree = pipelineTree.subTree(agg.getName());
aggs.add(((InternalAggregation) agg).reducePipelines((InternalAggregation) agg, reduceContext, subTree));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
long docCount = 0L;
List<InternalAggregations> subAggregationsList = new ArrayList<>(aggregations.size());
for (InternalAggregation aggregation : aggregations) {
reduceContext.checkCancelled();
assert aggregation.getName().equals(getName());
docCount += ((InternalSingleBucketAggregation) aggregation).docCount;
subAggregationsList.add(((InternalSingleBucketAggregation) aggregation).aggregations);
Expand Down
Loading
Loading