diff --git a/CHANGELOG.md b/CHANGELOG.md index 648cd3d0963e6..576bdc75c25c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,7 +78,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix 'system call filter not installed' caused when network.host: 0.0.0.0 ([#18309](https://github.com/opensearch-project/OpenSearch/pull/18309)) - Fix MatrixStatsAggregator reuse when mode parameter changes ([#18242](https://github.com/opensearch-project/OpenSearch/issues/18242)) - Replace the deprecated construction method of TopScoreDocCollectorManager with the new method ([#18395](https://github.com/opensearch-project/OpenSearch/pull/18395)) -- Fixed Approximate Framework regression with Lucene 10.2.1 by updating `intersectRight` BKD walk and `IntRef` visit method ([#18358](https://github.com/opensearch-project/OpenSearch/issues/18358 +- Fixed Approximate Framework regression with Lucene 10.2.1 by updating `intersectRight` BKD walk and `IntRef` visit method ([#18358](https://github.com/opensearch-project/OpenSearch/issues/18358)) +- Add task cancellation checks in aggregators ([#18426](https://github.com/opensearch-project/OpenSearch/pull/18426)) ### Security diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java index f91bf972a3d28..07f2586ac756a 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java @@ -38,6 +38,7 @@ import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.indices.breaker.CircuitBreakerService; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; @@ -328,4 +329,10 @@ protected final InternalAggregations buildEmptySubAggregations() { public String toString() { return name; } + + protected void checkCancelled() { + if (context.isCancelled()) { + throw new TaskCancelledException("The query has been cancelled"); + } + } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java index a65728b2d658a..4b252de116e5d 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java @@ -235,9 +235,11 @@ protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {} * array of ordinals */ protected final InternalAggregations[] buildSubAggsForBuckets(long[] bucketOrdsToCollect) throws IOException { + checkCancelled(); beforeBuildingBuckets(bucketOrdsToCollect); InternalAggregation[][] aggregations = new InternalAggregation[subAggregators.length][]; for (int i = 0; i < subAggregators.length; i++) { + checkCancelled(); aggregations[i] = subAggregators[i].buildAggregations(bucketOrdsToCollect); } InternalAggregations[] result = new InternalAggregations[bucketOrdsToCollect.length]; @@ -323,6 +325,7 @@ protected final InternalAggregation[] buildAggregationsForFixedBucketCount( BucketBuilderForFixedCount bucketBuilder, Function, InternalAggregation> resultBuilder ) throws IOException { + checkCancelled(); int totalBuckets = owningBucketOrds.length * bucketsPerOwningBucketOrd; long[] bucketOrdsToCollect = new long[totalBuckets]; int bucketOrdIdx = 0; @@ -373,6 +376,7 @@ protected final InternalAggregation[] buildAggregationsForSingleBucket(long[] ow * `consumeBucketsAndMaybeBreak(owningBucketOrds.length)` * here but we don't because single bucket aggs never have. */ + checkCancelled(); InternalAggregations[] subAggregationResults = buildSubAggsForBuckets(owningBucketOrds); InternalAggregation[] results = new InternalAggregation[owningBucketOrds.length]; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { @@ -403,6 +407,7 @@ protected final InternalAggregation[] buildAggregationsForVariableBuckets( BucketBuilderForVariable bucketBuilder, ResultBuilderForVariable resultBuilder ) throws IOException { + checkCancelled(); long totalOrdsToCollect = 0; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { totalOrdsToCollect += bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java index f82ee9dc242fb..b2bb3b1ffe786 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java @@ -208,6 +208,7 @@ public void collect(int doc, long bucket) throws IOException { @Override public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + checkCancelled(); // Buckets are ordered into groups - [keyed filters] [key1&key2 intersects] int maxOrd = owningBucketOrds.length * totalNumKeys; int totalBucketsToBuild = 0; diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java index 471baf52b9303..d08c8121b0535 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -260,6 +260,7 @@ protected void doPostCollection() throws IOException { @Override public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + checkCancelled(); // Composite aggregator must be at the top of the aggregation tree assert owningBucketOrds.length == 1 && owningBucketOrds[0] == 0L; if (deferredCollectors != NO_OP_COLLECTOR) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregator.java index 7b86d0ed15cf8..9ea4c66a9a1c4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filter/FiltersAggregator.java @@ -200,6 +200,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I owningBucketOrds, keys.length + (showOtherBucket ? 1 : 0), (offsetInOwningOrd, docCount, subAggregationResults) -> { + checkCancelled(); if (offsetInOwningOrd < keys.length) { return new InternalFilters.InternalBucket(keys[offsetInOwningOrd], docCount, subAggregationResults, keyed); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AbstractHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AbstractHistogramAggregator.java index d3a4a51e5b6f2..f41ef17330212 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AbstractHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AbstractHistogramAggregator.java @@ -104,10 +104,12 @@ public AbstractHistogramAggregator( @Override public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { return buildAggregationsForVariableBuckets(owningBucketOrds, bucketOrds, (bucketValue, docCount, subAggregationResults) -> { + checkCancelled(); double roundKey = Double.longBitsToDouble(bucketValue); double key = roundKey * interval + offset; return new InternalHistogram.Bucket(key, docCount, keyed, formatter, subAggregationResults); }, (owningBucketOrd, buckets) -> { + checkCancelled(); // the contract of the histogram aggregation is that shards must return buckets ordered by key in ascending order CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator()); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java index 3c9fc2dcb0e43..8fa9c61821fd8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java @@ -285,6 +285,7 @@ protected final InternalAggregation[] buildAggregations( subAggregationResults ), (owningBucketOrd, buckets) -> { + checkCancelled(); // the contract of the histogram aggregation is that shards must return // buckets ordered by key in ascending order CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator()); @@ -733,6 +734,7 @@ private int increaseRoundingIfNeeded(long owningBucketOrd, int oldEstimatedBucke private void rebucket() { rebucketCount++; try (LongKeyedBucketOrds oldOrds = bucketOrds) { + checkCancelled(); long[] mergeMap = new long[Math.toIntExact(oldOrds.size())]; bucketOrds = new LongKeyedBucketOrds.FromMany(context.bigArrays()); for (long owningBucketOrd = 0; owningBucketOrd <= oldOrds.maxOwningBucketOrd(); owningBucketOrd++) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateRangeHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateRangeHistogramAggregator.java index f70a5bb537300..ce9892081d998 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateRangeHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateRangeHistogramAggregator.java @@ -207,6 +207,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I subAggregationResults ), (owningBucketOrd, buckets) -> { + checkCancelled(); // the contract of the histogram aggregation is that shards must return buckets ordered by key in ascending order CollectionUtil.introSort(buckets, BucketOrder.key(true).comparator()); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregator.java index 526945243c786..fdb0c025e880c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/VariableWidthHistogramAggregator.java @@ -585,6 +585,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I List buckets = new ArrayList<>(numClusters); for (int bucketOrd = 0; bucketOrd < numClusters; bucketOrd++) { + checkCancelled(); buckets.add(collector.buildBucket(bucketOrd, subAggregationResults[bucketOrd])); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index 02186a6a99079..769fd55329c90 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -469,6 +469,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I owningBucketOrds, ranges.length, (offsetInOwningOrd, docCount, subAggregationResults) -> { + checkCancelled(); Range range = ranges[offsetInOwningOrd]; return rangeFactory.createBucket(range.key, range.from, range.to, docCount, subAggregationResults, keyed, format); }, diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java index 130444a3d673f..b75639e993ec3 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java @@ -805,6 +805,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws B[][] topBucketsPreOrd = buildTopBucketsPerOrd(owningBucketOrds.length); long[] otherDocCount = new long[owningBucketOrds.length]; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { + checkCancelled(); final int size; if (localBucketCountThresholds.getMinDocCount() == 0) { // if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongRareTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongRareTermsAggregator.java index 6e4cd895e7496..483b678430e41 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongRareTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongRareTermsAggregator.java @@ -133,6 +133,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I long offset = 0; for (int owningOrdIdx = 0; owningOrdIdx < owningBucketOrds.length; owningOrdIdx++) { try (LongHash bucketsInThisOwningBucketToCollect = new LongHash(1, context.bigArrays())) { + checkCancelled(); filters[owningOrdIdx] = newFilter(); List builtBuckets = new ArrayList<>(); LongKeyedBucketOrds.BucketOrdsEnum collectedBuckets = bucketOrds.ordsEnum(owningBucketOrds[owningOrdIdx]); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java index ade23f7290f89..7fd4e12ad39c4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java @@ -249,6 +249,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws B[][] topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.length); long[] otherDocCounts = new long[owningBucketOrds.length]; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { + checkCancelled(); collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]); int size = (int) Math.min(bucketOrds.size(), localBucketCountThresholds.getRequiredSize()); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java index 07edf487af670..847170bc3ba12 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java @@ -125,6 +125,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I InternalMultiTerms.Bucket[][] topBucketsPerOrd = new InternalMultiTerms.Bucket[owningBucketOrds.length][]; long[] otherDocCounts = new long[owningBucketOrds.length]; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { + checkCancelled(); collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]); long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java index bcdea9fb4af3c..c34ac7aa20ce5 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/NumericTermsAggregator.java @@ -259,6 +259,7 @@ private InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws B[][] topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.length); long[] otherDocCounts = new long[owningBucketOrds.length]; for (int ordIdx = 0; ordIdx < owningBucketOrds.length; ordIdx++) { + checkCancelled(); collectZeroDocEntriesIfNeeded(owningBucketOrds[ordIdx]); long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrds[ordIdx]); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StringRareTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StringRareTermsAggregator.java index cc35fe75e5e92..6a4443adbb42d 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StringRareTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/StringRareTermsAggregator.java @@ -136,6 +136,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I long offset = 0; for (int owningOrdIdx = 0; owningOrdIdx < owningBucketOrds.length; owningOrdIdx++) { try (BytesRefHash bucketsInThisOwningBucketToCollect = new BytesRefHash(context.bigArrays())) { + checkCancelled(); filters[owningOrdIdx] = newFilter(); List builtBuckets = new ArrayList<>(); BytesKeyedBucketOrds.BucketOrdsEnum collectedBuckets = bucketOrds.ordsEnum(owningBucketOrds[owningOrdIdx]); diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java index 5732d545cb2d2..c0254e5de5048 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSource.java @@ -80,6 +80,7 @@ */ @PublicApi(since = "1.0.0") public abstract class ValuesSource { + private Runnable cancellationCheck; /** * Get the current {@link BytesValues}. @@ -101,6 +102,10 @@ public boolean needsScores() { */ public abstract Function roundingPreparer(IndexReader reader) throws IOException; + protected void setCancellationCheck(Runnable cancellationCheck) { + this.cancellationCheck = cancellationCheck; + } + /** * Check if this values source supports using global ordinals */ diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java index 73ab8d7dc814c..e0cd8205ef086 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregationSetupTests.java @@ -8,6 +8,7 @@ package org.opensearch.search.aggregations; +import org.opensearch.action.search.SearchShardTask; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexService; @@ -38,6 +39,8 @@ public void setUp() throws Exception { client().admin().indices().prepareRefresh("idx").get(); context = createSearchContext(index); ((TestSearchContext) context).setConcurrentSegmentSearchEnabled(true); + SearchShardTask task = new SearchShardTask(0, "n/a", "n/a", "test-kind", null, null); + context.setTask(task); } protected AggregatorFactories getAggregationFactories(String agg) throws IOException { diff --git a/server/src/test/java/org/opensearch/search/aggregations/AggregatorCancellationTests.java b/server/src/test/java/org/opensearch/search/aggregations/AggregatorCancellationTests.java new file mode 100644 index 0000000000000..5cdb089c01bef --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/AggregatorCancellationTests.java @@ -0,0 +1,200 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.SortedSetDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalSum; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class AggregatorCancellationTests extends AggregatorTestCase { + + @Override + protected List getSearchPlugins() { + return List.of(new SearchPlugin() { + @Override + public List getAggregations() { + return List.of( + new AggregationSpec( + "custom_cancellable", + CustomCancellableAggregationBuilder::new, + CustomCancellableAggregationBuilder.PARSER + ) + ); + } + }); + } + + public void testNestedAggregationCancellation() throws IOException { + try (Directory directory = newDirectory()) { + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + + // Create documents + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new SortedSetDocValuesField("category", new BytesRef("cat" + (i % 10)))); + doc.add(new SortedSetDocValuesField("subcategory", new BytesRef("subcat" + (i % 50)))); + doc.add(new SortedSetDocValuesField("brand", new BytesRef("brand" + (i % 20)))); + doc.add(new SortedNumericDocValuesField("value", i)); + indexWriter.addDocument(doc); + } + indexWriter.close(); + + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = newIndexSearcher(reader); + + // Create nested aggregations with our custom cancellable agg + CustomCancellableAggregationBuilder aggBuilder = new CustomCancellableAggregationBuilder("test_agg").subAggregation( + new TermsAggregationBuilder("categories").field("category") + .size(10) + .subAggregation( + new TermsAggregationBuilder("subcategories").field("subcategory") + .size(50000) + .subAggregation(new TermsAggregationBuilder("brands").field("brand").size(20000)) + ) + ); + + expectThrows( + OpenSearchRejectedExecutionException.class, + () -> searchAndReduce( + searcher, + new MatchAllDocsQuery(), + aggBuilder, + keywordField("category"), + keywordField("subcategory"), + keywordField("brand") + ) + ); + } + } + } + + private static class CustomCancellableAggregationBuilder extends AbstractAggregationBuilder { + + public static final String NAME = "custom_cancellable"; + + public static final ObjectParser PARSER = ObjectParser.fromBuilder( + NAME, + CustomCancellableAggregationBuilder::new + ); + + CustomCancellableAggregationBuilder(String name) { + super(name); + } + + public CustomCancellableAggregationBuilder(StreamInput in) throws IOException { + super(in); + } + + @Override + protected AggregatorFactory doBuild( + QueryShardContext queryShardContext, + AggregatorFactory parent, + AggregatorFactories.Builder subfactoriesBuilder + ) throws IOException { + return new AggregatorFactory(name, queryShardContext, parent, subfactoriesBuilder, metadata) { + @Override + protected Aggregator createInternal( + SearchContext searchContext, + Aggregator parent, + CardinalityUpperBound cardinality, + Map metadata + ) throws IOException { + return new CustomCancellableAggregator( + name, + searchContext, + parent, + subfactoriesBuilder.build(queryShardContext, this), + null + ); + } + }; + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } + + @Override + public CustomCancellableAggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new CustomCancellableAggregationBuilder(getName()).setMetadata(metadata).subAggregations(factoriesBuilder); + } + + public String getType() { + return NAME; + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.NONE; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // Nothing to write + } + } + + private static class CustomCancellableAggregator extends AggregatorBase { + + CustomCancellableAggregator( + String name, + SearchContext context, + Aggregator parent, + AggregatorFactories factories, + Map metadata + ) throws IOException { + super(name, factories, context, parent, CardinalityUpperBound.NONE, metadata); + } + + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + checkCancelled(); + return LeafBucketCollector.NO_OP_COLLECTOR; + } + + protected void checkCancelled() { + throw new OpenSearchRejectedExecutionException("The request has been cancelled"); + } + + @Override + public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + InternalAggregation internalAggregation = new InternalSum(name(), 0.0, DocValueFormat.RAW, metadata()); + return new InternalAggregation[] { internalAggregation }; + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalSum(name(), 0.0, DocValueFormat.RAW, metadata()); + } + } +}