Skip to content

Commit bd11e6c

Browse files
authored
Fix NPE on composite aggregation with sub-aggregations that need scores (#28129)
The composite aggregation defers the collection of sub-aggregations to a second pass that visits documents only if they appear in the top buckets. Though the scorer for sub-aggregations is not set on this second pass and generates an NPE if any sub-aggregation tries to access the score. This change creates a scorer for the second pass and makes sure that sub-aggs can use it safely to check the score of the collected documents.
1 parent ee7eac8 commit bd11e6c

File tree

3 files changed

+123
-12
lines changed

3 files changed

+123
-12
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import org.apache.lucene.search.CollectionTerminatedException;
2424
import org.apache.lucene.search.DocIdSet;
2525
import org.apache.lucene.search.DocIdSetIterator;
26+
import org.apache.lucene.search.Query;
27+
import org.apache.lucene.search.Scorer;
28+
import org.apache.lucene.search.Weight;
2629
import org.apache.lucene.util.RoaringDocIdSet;
2730
import org.elasticsearch.search.aggregations.Aggregator;
2831
import org.elasticsearch.search.aggregations.AggregatorFactories;
@@ -87,6 +90,12 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException
8790

8891
// Replay all documents that contain at least one top bucket (collected during the first pass).
8992
grow(keys.size()+1);
93+
final boolean needsScores = needsScores();
94+
Weight weight = null;
95+
if (needsScores) {
96+
Query query = context.query();
97+
weight = context.searcher().createNormalizedWeight(query, true);
98+
}
9099
for (LeafContext context : contexts) {
91100
DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
92101
if (docIdSetIterator == null) {
@@ -95,7 +104,21 @@ public InternalAggregation buildAggregation(long zeroBucket) throws IOException
95104
final CompositeValuesSource.Collector collector =
96105
array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
97106
int docID;
107+
DocIdSetIterator scorerIt = null;
108+
if (needsScores) {
109+
Scorer scorer = weight.scorer(context.ctx);
110+
// We don't need to check if the scorer is null
111+
// since we are sure that there are documents to replay (docIdSetIterator it not empty).
112+
scorerIt = scorer.iterator();
113+
context.subCollector.setScorer(scorer);
114+
}
98115
while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
116+
if (needsScores) {
117+
assert scorerIt.docID() < docID;
118+
scorerIt.advance(docID);
119+
// aggregations should only be replayed on matching documents
120+
assert scorerIt.docID() == docID;
121+
}
99122
collector.collect(docID);
100123
}
101124
}

server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregatorTests.java

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
import org.elasticsearch.index.mapper.NumberFieldMapper;
5151
import org.elasticsearch.search.aggregations.AggregatorTestCase;
5252
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
53+
import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
54+
import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
5355
import org.elasticsearch.search.sort.SortOrder;
5456
import org.elasticsearch.test.IndexSettingsModule;
5557
import org.joda.time.DateTimeZone;
@@ -1065,8 +1067,73 @@ public void testWithKeywordAndDateHistogram() throws IOException {
10651067
);
10661068
}
10671069

1068-
private void testSearchCase(Query query,
1069-
Sort sort,
1070+
public void testWithKeywordAndTopHits() throws Exception {
1071+
final List<Map<String, List<Object>>> dataset = new ArrayList<>();
1072+
dataset.addAll(
1073+
Arrays.asList(
1074+
createDocument("keyword", "a"),
1075+
createDocument("keyword", "c"),
1076+
createDocument("keyword", "a"),
1077+
createDocument("keyword", "d"),
1078+
createDocument("keyword", "c")
1079+
)
1080+
);
1081+
final Sort sort = new Sort(new SortedSetSortField("keyword", false));
1082+
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
1083+
() -> {
1084+
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
1085+
.field("keyword");
1086+
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
1087+
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
1088+
}, (result) -> {
1089+
assertEquals(3, result.getBuckets().size());
1090+
assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
1091+
assertEquals(2L, result.getBuckets().get(0).getDocCount());
1092+
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
1093+
assertNotNull(topHits);
1094+
assertEquals(topHits.getHits().getHits().length, 2);
1095+
assertEquals(topHits.getHits().getTotalHits(), 2L);
1096+
assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
1097+
assertEquals(2L, result.getBuckets().get(1).getDocCount());
1098+
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
1099+
assertNotNull(topHits);
1100+
assertEquals(topHits.getHits().getHits().length, 2);
1101+
assertEquals(topHits.getHits().getTotalHits(), 2L);
1102+
assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
1103+
assertEquals(1L, result.getBuckets().get(2).getDocCount());
1104+
topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
1105+
assertNotNull(topHits);
1106+
assertEquals(topHits.getHits().getHits().length, 1);
1107+
assertEquals(topHits.getHits().getTotalHits(), 1L);;
1108+
}
1109+
);
1110+
1111+
testSearchCase(new MatchAllDocsQuery(), sort, dataset,
1112+
() -> {
1113+
TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
1114+
.field("keyword");
1115+
return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
1116+
.aggregateAfter(Collections.singletonMap("keyword", "a"))
1117+
.subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
1118+
}, (result) -> {
1119+
assertEquals(2, result.getBuckets().size());
1120+
assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
1121+
assertEquals(2L, result.getBuckets().get(0).getDocCount());
1122+
TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
1123+
assertNotNull(topHits);
1124+
assertEquals(topHits.getHits().getHits().length, 2);
1125+
assertEquals(topHits.getHits().getTotalHits(), 2L);
1126+
assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
1127+
assertEquals(1L, result.getBuckets().get(1).getDocCount());
1128+
topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
1129+
assertNotNull(topHits);
1130+
assertEquals(topHits.getHits().getHits().length, 1);
1131+
assertEquals(topHits.getHits().getTotalHits(), 1L);
1132+
}
1133+
);
1134+
}
1135+
1136+
private void testSearchCase(Query query, Sort sort,
10701137
List<Map<String, List<Object>>> dataset,
10711138
Supplier<CompositeAggregationBuilder> create,
10721139
Consumer<InternalComposite> verify) throws IOException {
@@ -1107,7 +1174,7 @@ private void executeTestCase(boolean reduced,
11071174
IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
11081175
CompositeAggregationBuilder aggregationBuilder = create.get();
11091176
if (sort != null) {
1110-
CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
1177+
CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
11111178
assertTrue(aggregator.canEarlyTerminate());
11121179
}
11131180
final InternalComposite composite;

test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,27 @@ protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggreg
103103
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
104104
}
105105

106-
/** Create a factory for the given aggregation builder. */
106+
107107
protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
108108
IndexSearcher indexSearcher,
109109
IndexSettings indexSettings,
110110
MultiBucketConsumer bucketConsumer,
111111
MappedFieldType... fieldTypes) throws IOException {
112+
return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
113+
}
114+
115+
/** Create a factory for the given aggregation builder. */
116+
protected AggregatorFactory<?> createAggregatorFactory(Query query,
117+
AggregationBuilder aggregationBuilder,
118+
IndexSearcher indexSearcher,
119+
IndexSettings indexSettings,
120+
MultiBucketConsumer bucketConsumer,
121+
MappedFieldType... fieldTypes) throws IOException {
112122
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
113123
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
114124
when(searchContext.aggregations())
115125
.thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
126+
when(searchContext.query()).thenReturn(query);
116127
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
117128
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
118129
MapperService mapperService = mapperServiceMock();
@@ -146,28 +157,38 @@ protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregati
146157
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
147158
}
148159

149-
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
160+
protected <A extends Aggregator> A createAggregator(Query query,
161+
AggregationBuilder aggregationBuilder,
150162
IndexSearcher indexSearcher,
151163
IndexSettings indexSettings,
152164
MappedFieldType... fieldTypes) throws IOException {
153-
return createAggregator(aggregationBuilder, indexSearcher, indexSettings,
165+
return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
154166
new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
155167
}
156168

157-
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
169+
protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
158170
IndexSearcher indexSearcher,
159171
MultiBucketConsumer bucketConsumer,
160172
MappedFieldType... fieldTypes) throws IOException {
161-
return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
173+
return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
162174
}
163175

164176
protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
165177
IndexSearcher indexSearcher,
166178
IndexSettings indexSettings,
167179
MultiBucketConsumer bucketConsumer,
168180
MappedFieldType... fieldTypes) throws IOException {
181+
return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
182+
}
183+
184+
protected <A extends Aggregator> A createAggregator(Query query,
185+
AggregationBuilder aggregationBuilder,
186+
IndexSearcher indexSearcher,
187+
IndexSettings indexSettings,
188+
MultiBucketConsumer bucketConsumer,
189+
MappedFieldType... fieldTypes) throws IOException {
169190
@SuppressWarnings("unchecked")
170-
A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
191+
A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
171192
.create(null, true);
172193
return aggregator;
173194
}
@@ -262,7 +283,7 @@ protected <A extends InternalAggregation, C extends Aggregator> A search(IndexSe
262283
int maxBucket,
263284
MappedFieldType... fieldTypes) throws IOException {
264285
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
265-
C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
286+
C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
266287
a.preCollection();
267288
searcher.search(query, a);
268289
a.postCollection();
@@ -310,11 +331,11 @@ protected <A extends InternalAggregation, C extends Aggregator> A searchAndReduc
310331
Query rewritten = searcher.rewrite(query);
311332
Weight weight = searcher.createWeight(rewritten, true, 1f);
312333
MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
313-
C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
334+
C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
314335

315336
for (ShardSearcher subSearcher : subSearchers) {
316337
MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
317-
C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes);
338+
C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
318339
a.preCollection();
319340
subSearcher.search(weight, a);
320341
a.postCollection();

0 commit comments

Comments
 (0)