3232package org .opensearch .search .aggregations .bucket .range ;
3333
3434import org .apache .lucene .index .LeafReaderContext ;
35+ import org .apache .lucene .search .DocIdSetIterator ;
3536import org .apache .lucene .search .ScoreMode ;
37+ import org .apache .lucene .util .FixedBitSet ;
3638import org .opensearch .core .ParseField ;
3739import org .opensearch .core .common .io .stream .StreamInput ;
3840import org .opensearch .core .common .io .stream .StreamOutput ;
4345import org .opensearch .core .xcontent .ToXContentObject ;
4446import org .opensearch .core .xcontent .XContentBuilder ;
4547import org .opensearch .core .xcontent .XContentParser ;
48+ import org .opensearch .index .codec .composite .CompositeIndexFieldInfo ;
49+ import org .opensearch .index .compositeindex .datacube .MetricStat ;
50+ import org .opensearch .index .compositeindex .datacube .startree .index .StarTreeValues ;
51+ import org .opensearch .index .compositeindex .datacube .startree .utils .StarTreeUtils ;
52+ import org .opensearch .index .compositeindex .datacube .startree .utils .iterator .SortedNumericStarTreeValuesIterator ;
4653import org .opensearch .index .fielddata .SortedNumericDoubleValues ;
54+ import org .opensearch .index .mapper .NumberFieldMapper ;
4755import org .opensearch .search .DocValueFormat ;
4856import org .opensearch .search .aggregations .Aggregator ;
4957import org .opensearch .search .aggregations .AggregatorFactories ;
5361import org .opensearch .search .aggregations .LeafBucketCollector ;
5462import org .opensearch .search .aggregations .LeafBucketCollectorBase ;
5563import org .opensearch .search .aggregations .NonCollectingAggregator ;
64+ import org .opensearch .search .aggregations .StarTreeBucketCollector ;
65+ import org .opensearch .search .aggregations .StarTreePreComputeCollector ;
5666import org .opensearch .search .aggregations .bucket .BucketsAggregator ;
5767import org .opensearch .search .aggregations .bucket .filterrewrite .FilterRewriteOptimizationContext ;
5868import org .opensearch .search .aggregations .bucket .filterrewrite .RangeAggregatorBridge ;
5969import org .opensearch .search .aggregations .support .ValuesSource ;
6070import org .opensearch .search .aggregations .support .ValuesSourceConfig ;
6171import org .opensearch .search .internal .SearchContext ;
72+ import org .opensearch .search .startree .StarTreeQueryHelper ;
73+ import org .opensearch .search .startree .StarTreeTraversalUtil ;
74+ import org .opensearch .search .startree .filter .DimensionFilter ;
6275
6376import java .io .IOException ;
6477import java .util .ArrayList ;
7083
7184import static org .opensearch .core .xcontent .ConstructingObjectParser .optionalConstructorArg ;
7285import static org .opensearch .search .aggregations .bucket .filterrewrite .AggregatorBridge .segmentMatchAll ;
86+ import static org .opensearch .search .startree .StarTreeQueryHelper .getSupportedStarTree ;
7387
/**
 * Aggregates all documents that match the given ranges.
 *
 * @opensearch.internal
 */
79- public class RangeAggregator extends BucketsAggregator {
93+ public class RangeAggregator extends BucketsAggregator implements StarTreePreComputeCollector {
8094
8195 public static final ParseField RANGES_FIELD = new ParseField ("ranges" );
8296 public static final ParseField KEYED_FIELD = new ParseField ("keyed" );
97+ public final String fieldName ;
8398
8499 /**
85100 * Range for the range aggregator
@@ -298,6 +313,9 @@ protected Function<Object, Long> bucketOrdProducer() {
298313 }
299314 };
300315 filterRewriteOptimizationContext = new FilterRewriteOptimizationContext (bridge , parent , subAggregators .length , context );
316+ this .fieldName = (valuesSource instanceof ValuesSource .Numeric .FieldData )
317+ ? ((ValuesSource .Numeric .FieldData ) valuesSource ).getIndexFieldName ()
318+ : null ;
301319 }
302320
303321 @ Override
@@ -310,8 +328,13 @@ public ScoreMode scoreMode() {
310328
311329 @ Override
312330 protected boolean tryPrecomputeAggregationForLeaf (LeafReaderContext ctx ) throws IOException {
313- if (segmentMatchAll (context , ctx )) {
314- return filterRewriteOptimizationContext .tryOptimize (ctx , this ::incrementBucketDocCount , false );
331+ if (segmentMatchAll (context , ctx ) && filterRewriteOptimizationContext .tryOptimize (ctx , this ::incrementBucketDocCount , false )) {
332+ return true ;
333+ }
334+ CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree (this .context .getQueryShardContext ());
335+ if (supportedStarTree != null ) {
336+ preComputeWithStarTree (ctx , supportedStarTree );
337+ return true ;
315338 }
316339 return false ;
317340 }
@@ -333,52 +356,107 @@ public void collect(int doc, long bucket) throws IOException {
333356 }
334357
335358 private int collect (int doc , double value , long owningBucketOrdinal , int lowBound ) throws IOException {
336- int lo = lowBound , hi = ranges .length - 1 ; // all candidates are between these indexes
337- int mid = (lo + hi ) >>> 1 ;
338- while (lo <= hi ) {
339- if (value < ranges [mid ].from ) {
340- hi = mid - 1 ;
341- } else if (value >= maxTo [mid ]) {
342- lo = mid + 1 ;
343- } else {
344- break ;
359+ MatchedRange range = new MatchedRange (ranges , lowBound , value , maxTo );
360+ for (int i = range .startLo ; i <= range .endHi ; ++i ) {
361+ if (ranges [i ].matches (value )) {
362+ collectBucket (sub , doc , subBucketOrdinal (owningBucketOrdinal , i ));
345363 }
346- mid = (lo + hi ) >>> 1 ;
347364 }
348- if (lo > hi ) return lo ; // no potential candidate
349-
350- // binary search the lower bound
351- int startLo = lo , startHi = mid ;
352- while (startLo <= startHi ) {
353- final int startMid = (startLo + startHi ) >>> 1 ;
354- if (value >= maxTo [startMid ]) {
355- startLo = startMid + 1 ;
356- } else {
357- startHi = startMid - 1 ;
358- }
365+ return range .endHi + 1 ;
366+ }
367+ };
368+ }
369+
370+ private void preComputeWithStarTree (LeafReaderContext ctx , CompositeIndexFieldInfo starTree ) throws IOException {
371+ StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector (ctx , starTree , null );
372+ FixedBitSet matchingDocsBitSet = starTreeBucketCollector .getMatchingDocsBitSet ();
373+
374+ int numBits = matchingDocsBitSet .length ();
375+
376+ if (numBits > 0 ) {
377+ for (int bit = matchingDocsBitSet .nextSetBit (0 ); bit != DocIdSetIterator .NO_MORE_DOCS ; bit = (bit + 1 < numBits )
378+ ? matchingDocsBitSet .nextSetBit (bit + 1 )
379+ : DocIdSetIterator .NO_MORE_DOCS ) {
380+ starTreeBucketCollector .collectStarTreeEntry (bit , 0 );
381+ }
382+ }
383+ }
384+
385+ @ Override
386+ public StarTreeBucketCollector getStarTreeBucketCollector (
387+ LeafReaderContext ctx ,
388+ CompositeIndexFieldInfo starTree ,
389+ StarTreeBucketCollector parentCollector
390+ ) throws IOException {
391+ assert parentCollector == null ;
392+ StarTreeValues starTreeValues = StarTreeQueryHelper .getStarTreeValues (ctx , starTree );
393+ // TODO: Evaluate optimizing StarTree traversal filter with specific ranges instead of MATCH_ALL_DEFAULT
394+ return new StarTreeBucketCollector (
395+ starTreeValues ,
396+ StarTreeTraversalUtil .getStarTreeResult (
397+ starTreeValues ,
398+ StarTreeQueryHelper .mergeDimensionFilterIfNotExists (
399+ context .getQueryShardContext ().getStarTreeQueryContext ().getBaseQueryStarTreeFilter (),
400+ fieldName ,
401+ List .of (DimensionFilter .MATCH_ALL_DEFAULT )
402+ ),
403+ context
404+ )
405+ ) {
406+ @ Override
407+ public void setSubCollectors () throws IOException {
408+ for (Aggregator aggregator : subAggregators ) {
409+ this .subCollectors .add (((StarTreePreComputeCollector ) aggregator ).getStarTreeBucketCollector (ctx , starTree , this ));
410+ }
411+ }
412+
413+ SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator ) starTreeValues
414+ .getDimensionValuesIterator (fieldName );
415+
416+ String metricName = StarTreeUtils .fullyQualifiedFieldNameForStarTreeMetricsDocValues (
417+ starTree .getField (),
418+ "_doc_count" ,
419+ MetricStat .DOC_COUNT .getTypeName ()
420+ );
421+
422+ SortedNumericStarTreeValuesIterator docCountsIterator = (SortedNumericStarTreeValuesIterator ) starTreeValues
423+ .getMetricValuesIterator (metricName );
424+
425+ @ Override
426+ public void collectStarTreeEntry (int starTreeEntry , long owningBucketOrd ) throws IOException {
427+ if (!valuesIterator .advanceExact (starTreeEntry )) {
428+ return ;
359429 }
360430
361- // binary search the upper bound
362- int endLo = mid , endHi = hi ;
363- while (endLo <= endHi ) {
364- final int endMid = (endLo + endHi ) >>> 1 ;
365- if (value < ranges [endMid ].from ) {
366- endHi = endMid - 1 ;
431+ for (int i = 0 , count = valuesIterator .entryValueCount (); i < count ; i ++) {
432+ long dimensionLongValue = valuesIterator .nextValue ();
433+ double dimensionValue ;
434+
435+ // Only numeric & floating points are supported as of now in star-tree
436+ // TODO: Add support for isBigInteger() when it gets supported in star-tree
437+ if (valuesSource .isFloatingPoint ()) {
438+ dimensionValue = ((NumberFieldMapper .NumberFieldType ) context .mapperService ().fieldType (fieldName )).toDoubleValue (
439+ dimensionLongValue
440+ );
367441 } else {
368- endLo = endMid + 1 ;
442+ dimensionValue = dimensionLongValue ;
369443 }
370- }
371444
372- assert startLo == lowBound || value >= maxTo [startLo - 1 ];
373- assert endHi == ranges .length - 1 || value < ranges [endHi + 1 ].from ;
445+ MatchedRange matchedRange = new MatchedRange (ranges , 0 , dimensionValue , maxTo );
446+ if (matchedRange .startLo > matchedRange .endHi ) {
447+ continue ; // No matching range
448+ }
374449
375- for (int i = startLo ; i <= endHi ; ++i ) {
376- if (ranges [i ].matches (value )) {
377- collectBucket (sub , doc , subBucketOrdinal (owningBucketOrdinal , i ));
450+ if (docCountsIterator .advanceExact (starTreeEntry )) {
451+ long metricValue = docCountsIterator .nextValue ();
452+ for (int j = matchedRange .startLo ; j <= matchedRange .endHi ; ++j ) {
453+ if (ranges [j ].matches (dimensionValue )) {
454+ long bucketOrd = subBucketOrdinal (owningBucketOrd , j );
455+ collectStarTreeBucket (this , metricValue , bucketOrd , starTreeEntry );
456+ }
457+ }
378458 }
379459 }
380-
381- return endHi + 1 ;
382460 }
383461 };
384462 }
@@ -421,6 +499,63 @@ public InternalAggregation buildEmptyAggregation() {
421499 return rangeFactory .create (name , buckets , format , keyed , metadata ());
422500 }
423501
502+ static class MatchedRange {
503+ int startLo , endHi ;
504+
505+ MatchedRange (RangeAggregator .Range [] ranges , int lowBound , double value , double [] maxTo ) {
506+ computeMatchingRange (ranges , lowBound , value , maxTo );
507+ }
508+
509+ private void computeMatchingRange (RangeAggregator .Range [] ranges , int lowBound , double value , double [] maxTo ) {
510+ int lo = lowBound , hi = ranges .length - 1 ;
511+ int mid = (lo + hi ) >>> 1 ;
512+
513+ while (lo <= hi ) {
514+ if (value < ranges [mid ].from ) {
515+ hi = mid - 1 ;
516+ } else if (value >= maxTo [mid ]) {
517+ lo = mid + 1 ;
518+ } else {
519+ break ;
520+ }
521+ mid = (lo + hi ) >>> 1 ;
522+ }
523+ if (lo > hi ) {
524+ this .startLo = lo ;
525+ this .endHi = lo - 1 ;
526+ return ;
527+ }
528+
529+ // binary search the lower bound
530+ int startLo = lo , startHi = mid ;
531+ while (startLo <= startHi ) {
532+ int startMid = (startLo + startHi ) >>> 1 ;
533+ if (value >= maxTo [startMid ]) {
534+ startLo = startMid + 1 ;
535+ } else {
536+ startHi = startMid - 1 ;
537+ }
538+ }
539+
540+ // binary search the upper bound
541+ int endLo = mid , endHi = hi ;
542+ while (endLo <= endHi ) {
543+ int endMid = (endLo + endHi ) >>> 1 ;
544+ if (value < ranges [endMid ].from ) {
545+ endHi = endMid - 1 ;
546+ } else {
547+ endLo = endMid + 1 ;
548+ }
549+ }
550+
551+ assert startLo == lowBound || value >= maxTo [startLo - 1 ];
552+ assert endHi == ranges .length - 1 || value < ranges [endHi + 1 ].from ;
553+
554+ this .startLo = startLo ;
555+ this .endHi = endHi ;
556+ }
557+ }
558+
424559 /**
425560 * Unmapped range
426561 *
@@ -456,7 +591,7 @@ public Unmapped(
456591 public InternalAggregation buildEmptyAggregation () {
457592 InternalAggregations subAggs = buildEmptySubAggregations ();
458593 List <org .opensearch .search .aggregations .bucket .range .Range .Bucket > buckets = new ArrayList <>(ranges .length );
459- for (RangeAggregator . Range range : ranges ) {
594+ for (Range range : ranges ) {
460595 buckets .add (factory .createBucket (range .key , range .from , range .to , 0 , subAggs , keyed , format ));
461596 }
462597 return factory .create (name , buckets , format , keyed , metadata ());
0 commit comments