1717import org .opensearch .common .lucene .Lucene ;
1818import org .opensearch .index .codec .composite .CompositeIndexFieldInfo ;
1919import org .opensearch .index .codec .composite .CompositeIndexReader ;
20+ import org .opensearch .index .compositeindex .datacube .DateDimension ;
2021import org .opensearch .index .compositeindex .datacube .Dimension ;
2122import org .opensearch .index .compositeindex .datacube .Metric ;
2223import org .opensearch .index .compositeindex .datacube .MetricStat ;
2324import org .opensearch .index .compositeindex .datacube .startree .index .StarTreeValues ;
25+ import org .opensearch .index .compositeindex .datacube .startree .utils .date .DateTimeUnitAdapter ;
26+ import org .opensearch .index .compositeindex .datacube .startree .utils .date .DateTimeUnitRounding ;
2427import org .opensearch .index .compositeindex .datacube .startree .utils .iterator .SortedNumericStarTreeValuesIterator ;
2528import org .opensearch .index .mapper .CompositeDataCubeFieldType ;
2629import org .opensearch .index .query .MatchAllQueryBuilder ;
2730import org .opensearch .index .query .QueryBuilder ;
2831import org .opensearch .index .query .TermQueryBuilder ;
2932import org .opensearch .search .aggregations .AggregatorFactory ;
3033import org .opensearch .search .aggregations .LeafBucketCollector ;
31- import org .opensearch .search .aggregations .LeafBucketCollectorBase ;
34+ import org .opensearch .search .aggregations .StarTreeBucketCollector ;
35+ import org .opensearch .search .aggregations .bucket .histogram .DateHistogramAggregatorFactory ;
3236import org .opensearch .search .aggregations .metrics .MetricAggregatorFactory ;
3337import org .opensearch .search .aggregations .support .ValuesSource ;
3438import org .opensearch .search .builder .SearchSourceBuilder ;
3741import org .opensearch .search .startree .StarTreeQueryContext ;
3842
3943import java .io .IOException ;
40- import java .util .HashMap ;
4144import java .util .List ;
4245import java .util .Map ;
46+ import java .util .Set ;
47+ import java .util .function .BiConsumer ;
4348import java .util .function .Consumer ;
4449import java .util .stream .Collectors ;
4550
@@ -74,10 +79,16 @@ public static StarTreeQueryContext getStarTreeQueryContext(SearchContext context
7479 );
7580
7681 for (AggregatorFactory aggregatorFactory : context .aggregations ().factories ().getFactories ()) {
77- MetricStat metricStat = validateStarTreeMetricSupport (compositeMappedFieldType , aggregatorFactory );
78- if (metricStat == null ) {
79- return null ;
82+ // first check for aggregation is a metric aggregation
83+ if (validateStarTreeMetricSupport (compositeMappedFieldType , aggregatorFactory )) {
84+ continue ;
85+ }
86+
87+ // if not a metric aggregation, check for applicable date histogram shape
88+ if (validateDateHistogramSupport (compositeMappedFieldType , aggregatorFactory )) {
89+ continue ;
8090 }
91+ return null ;
8192 }
8293
8394 // need to cache star tree values only for multiple aggregations
@@ -100,63 +111,86 @@ private static StarTreeQueryContext tryCreateStarTreeQueryContext(
100111 if (queryBuilder == null || queryBuilder instanceof MatchAllQueryBuilder ) {
101112 queryMap = null ;
102113 } else if (queryBuilder instanceof TermQueryBuilder ) {
114+ TermQueryBuilder termQueryBuilder = (TermQueryBuilder ) queryBuilder ;
103115 // TODO: Add support for keyword fields
104- if (compositeFieldType .getDimensions ().stream ().anyMatch (d -> d .getDocValuesType () != DocValuesType .SORTED_NUMERIC )) {
105- // return null for non-numeric fields
106- return null ;
107- }
108-
109- List <String > supportedDimensions = compositeFieldType .getDimensions ()
116+ Dimension matchedDimension = compositeFieldType .getDimensions ()
110117 .stream ()
111- .map ( Dimension :: getField )
112- .collect ( Collectors . toList ());
113- queryMap = getStarTreePredicates ( queryBuilder , supportedDimensions );
114- if (queryMap == null ) {
118+ .filter ( d -> ( d . getField (). equals ( termQueryBuilder . fieldName ()) && d . getDocValuesType () == DocValuesType . SORTED_NUMERIC ) )
119+ .findFirst ()
120+ . orElse ( null );
121+ if (matchedDimension == null ) {
115122 return null ;
116123 }
124+ queryMap = Map .of (termQueryBuilder .fieldName (), Long .parseLong (termQueryBuilder .value ().toString ()));
117125 } else {
118126 return null ;
119127 }
120128 return new StarTreeQueryContext (compositeIndexFieldInfo , queryMap , cacheStarTreeValuesSize );
121129 }
122130
123- /**
124- * Parse query body to star-tree predicates
125- * @param queryBuilder to match star-tree supported query shape
126- * @return predicates to match
127- */
128- private static Map <String , Long > getStarTreePredicates (QueryBuilder queryBuilder , List <String > supportedDimensions ) {
129- TermQueryBuilder tq = (TermQueryBuilder ) queryBuilder ;
130- String field = tq .fieldName ();
131- if (!supportedDimensions .contains (field )) {
132- return null ;
133- }
134- long inputQueryVal = Long .parseLong (tq .value ().toString ());
135-
136- // Create a map with the field and the value
137- Map <String , Long > predicateMap = new HashMap <>();
138- predicateMap .put (field , inputQueryVal );
139- return predicateMap ;
140- }
141-
142- private static MetricStat validateStarTreeMetricSupport (
131+ private static boolean validateStarTreeMetricSupport (
143132 CompositeDataCubeFieldType compositeIndexFieldInfo ,
144133 AggregatorFactory aggregatorFactory
145134 ) {
146135 if (aggregatorFactory instanceof MetricAggregatorFactory && aggregatorFactory .getSubFactories ().getFactories ().length == 0 ) {
136+ MetricAggregatorFactory metricAggregatorFactory = (MetricAggregatorFactory ) aggregatorFactory ;
147137 String field ;
148138 Map <String , List <MetricStat >> supportedMetrics = compositeIndexFieldInfo .getMetrics ()
149139 .stream ()
150140 .collect (Collectors .toMap (Metric ::getField , Metric ::getMetrics ));
151141
152- MetricStat metricStat = ((MetricAggregatorFactory ) aggregatorFactory ).getMetricStat ();
153- field = ((MetricAggregatorFactory ) aggregatorFactory ).getField ();
142+ MetricStat metricStat = metricAggregatorFactory .getMetricStat ();
143+ field = metricAggregatorFactory .getField ();
144+
145+ return supportedMetrics .containsKey (field ) && supportedMetrics .get (field ).contains (metricStat );
146+ }
147+ return false ;
148+ }
149+
150+ private static boolean validateDateHistogramSupport (
151+ CompositeDataCubeFieldType compositeIndexFieldInfo ,
152+ AggregatorFactory aggregatorFactory
153+ ) {
154+ if (!(aggregatorFactory instanceof DateHistogramAggregatorFactory )
155+ || aggregatorFactory .getSubFactories ().getFactories ().length < 1 ) {
156+ return false ;
157+ }
158+ DateHistogramAggregatorFactory dateHistogramAggregatorFactory = (DateHistogramAggregatorFactory ) aggregatorFactory ;
159+
160+ // Find the DateDimension in the dimensions list
161+ DateDimension starTreeDateDimension = null ;
162+ for (Dimension dimension : compositeIndexFieldInfo .getDimensions ()) {
163+ if (dimension instanceof DateDimension ) {
164+ starTreeDateDimension = (DateDimension ) dimension ;
165+ break ;
166+ }
167+ }
168+
169+ // If no DateDimension is found, validation fails
170+ if (starTreeDateDimension == null ) {
171+ return false ;
172+ }
173+
174+ // Ensure the rounding is not null
175+ if (dateHistogramAggregatorFactory .getRounding () == null ) {
176+ return false ;
177+ }
178+
179+ // Find the closest valid interval in the DateTimeUnitRounding class associated with star tree
180+ DateTimeUnitRounding rounding = starTreeDateDimension .findClosestValidInterval (
181+ new DateTimeUnitAdapter (dateHistogramAggregatorFactory .getRounding ())
182+ );
183+ if (rounding == null ) {
184+ return false ;
185+ }
154186
155- if (field != null && supportedMetrics .containsKey (field ) && supportedMetrics .get (field ).contains (metricStat )) {
156- return metricStat ;
187+ // Validate all sub-factories
188+ for (AggregatorFactory subFactory : aggregatorFactory .getSubFactories ().getFactories ()) {
189+ if (!validateStarTreeMetricSupport (compositeIndexFieldInfo , subFactory )) {
190+ return false ;
157191 }
158192 }
159- return null ;
193+ return true ;
160194 }
161195
162196 public static CompositeIndexFieldInfo getSupportedStarTree (SearchContext context ) {
@@ -222,11 +256,37 @@ public static LeafBucketCollector getStarTreeLeafCollector(
222256 // Call the final consumer after processing all entries
223257 finalConsumer .run ();
224258
225- // Return a LeafBucketCollector that terminates collection
226- return new LeafBucketCollectorBase (sub , valuesSource .doubleValues (ctx )) {
259+ // Terminate after pre-computing aggregation
260+ throw new CollectionTerminatedException ();
261+ }
262+
263+ public static StarTreeBucketCollector getStarTreeBucketMetricCollector (
264+ CompositeIndexFieldInfo starTree ,
265+ String metric ,
266+ ValuesSource .Numeric valuesSource ,
267+ StarTreeBucketCollector parentCollector ,
268+ Consumer <Long > growArrays ,
269+ BiConsumer <Long , Long > updateBucket
270+ ) throws IOException {
271+ assert parentCollector != null ;
272+ return new StarTreeBucketCollector (parentCollector ) {
273+ String metricName = StarTreeUtils .fullyQualifiedFieldNameForStarTreeMetricsDocValues (
274+ starTree .getField (),
275+ ((ValuesSource .Numeric .FieldData ) valuesSource ).getIndexFieldName (),
276+ metric
277+ );
278+ SortedNumericStarTreeValuesIterator metricValuesIterator = (SortedNumericStarTreeValuesIterator ) starTreeValues
279+ .getMetricValuesIterator (metricName );
280+
227281 @ Override
228- public void collect (int doc , long bucket ) {
229- throw new CollectionTerminatedException ();
282+ public void collectStarTreeEntry (int starTreeEntryBit , long bucket ) throws IOException {
283+ growArrays .accept (bucket );
284+ // Advance the valuesIterator to the current bit
285+ if (!metricValuesIterator .advanceExact (starTreeEntryBit )) {
286+ return ; // Skip if no entries for this document
287+ }
288+ long metricValue = metricValuesIterator .nextValue ();
289+ updateBucket .accept (bucket , metricValue );
230290 }
231291 };
232292 }
@@ -240,7 +300,7 @@ public static FixedBitSet getStarTreeFilteredValues(SearchContext context, LeafR
240300 throws IOException {
241301 FixedBitSet result = context .getStarTreeQueryContext ().getStarTreeValues (ctx );
242302 if (result == null ) {
243- result = StarTreeFilter .getStarTreeResult (starTreeValues , context .getStarTreeQueryContext ().getQueryMap ());
303+ result = StarTreeFilter .getStarTreeResult (starTreeValues , context .getStarTreeQueryContext ().getQueryMap (), Set . of () );
244304 context .getStarTreeQueryContext ().setStarTreeValues (ctx , result );
245305 }
246306 return result ;
0 commit comments