diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java index 3f36ffa0f8e..cdd092addc1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java @@ -51,7 +51,20 @@ public List> parse(Aggregations aggregations) { private Map parse(CompositeAggregation.Bucket bucket) { Map resultMap = new HashMap<>(); resultMap.putAll(bucket.getKey()); + + // Parse regular metric aggregations resultMap.putAll(metricsParser.parse(bucket.getAggregations())); + + // Handle DocCountParser for optimized count aggregations + for (MetricParser parser : metricsParser.getMetricParserList()) { + if (parser instanceof DocCountParser) { + DocCountParser docCountParser = (DocCountParser) parser; + Map bucketMap = new HashMap<>(); + bucketMap.put("doc_count", bucket.getDocCount()); + resultMap.putAll(docCountParser.parseBucket(bucketMap)); + } + } + return resultMap; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/DocCountParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/DocCountParser.java new file mode 100644 index 00000000000..9a3ba24577e --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/DocCountParser.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.HashMap; +import java.util.Map; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.sql.data.model.ExprIntegerValue; + +/** + * Parser for extracting doc_count from bucket aggregations to optimize count() functions. + */ +public class DocCountParser implements MetricParser { + private final String name; + + public DocCountParser(String name) { + this.name = name; + } + + @Override + public Map parse(Aggregation aggregation) { + throw new UnsupportedOperationException( + "DocCountParser should be used with bucket context, not aggregations"); + } + + /** + * Parse doc_count from bucket map. + * + * @param bucket bucket map containing doc_count + * @return Map with the count value + */ + public Map parseBucket(Map bucket) { + Object docCount = bucket.get("doc_count"); + int count = (docCount instanceof Number) ? ((Number) docCount).intValue() : 0; + Map result = new HashMap<>(); + result.put(name, new ExprIntegerValue(count).value()); + return result; + } + + @Override + public String getName() { + return name; + } +} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java index 527748c8077..97a2eb38be7 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java @@ -32,7 +32,10 @@ public class MetricParserHelper { private final Map metricParserMap; + private final List metricParserList; + public MetricParserHelper(List metricParserList) { + this.metricParserList = metricParserList; metricParserMap = metricParserList.stream().collect(Collectors.toMap(MetricParser::getName, m -> m)); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java index 562e38ae79e..7a718bed302 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java @@ -73,8 +73,11 @@ public AggregationQueryBuilder(ExpressionSerializer serializer) { List> sortList, boolean bucketNullable) { + // Check if we can optimize count aggregations with bucket doc_count + boolean optimizeCount = canOptimizeCount(namedAggregatorList, groupByList); + final Pair> metrics = - metricBuilder.build(namedAggregatorList); + metricBuilder.build(namedAggregatorList, optimizeCount); if (groupByList.isEmpty()) { // no bucket @@ -111,6 +114,24 @@ public AggregationQueryBuilder(ExpressionSerializer serializer) { } } + /** + * Check if count aggregations can be optimized by using bucket doc_count. + * This is possible when we have both count aggregations and span expressions (date histogram). + */ + private boolean canOptimizeCount( + List namedAggregatorList, List groupByList) { + // Check if there are any count aggregations + boolean hasCountAgg = namedAggregatorList.stream() + .anyMatch(agg -> "count".equalsIgnoreCase(agg.getFunctionName().getFunctionName()) + && !agg.getDelegated().distinct()); + + // Check if there are any span expressions (which create date histogram buckets) + boolean hasSpanExpr = groupByList.stream() + .anyMatch(expr -> expr.getDelegated() instanceof org.opensearch.sql.expression.span.SpanExpression); + + return hasCountAgg && hasSpanExpr; + } + /** Build mapping for OpenSearchExprValueFactory. */ public Map buildTypeMapping( List namedAggregatorList, List groupByList) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 9cb3f9824b8..2cfbbbacf33 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -50,9 +50,31 @@ public MetricAggregationBuilder(ExpressionSerializer serializer) { */ public Pair> build( List aggregatorList) { + return build(aggregatorList, false); + } + + /** + * Build AggregatorFactories.Builder from {@link NamedAggregator} with optimization context. + * + * @param aggregatorList aggregator list + * @param optimizeCount whether to optimize count aggregations by using bucket doc_count + * @return AggregatorFactories.Builder + */ + public Pair> build( + List aggregatorList, boolean optimizeCount) { AggregatorFactories.Builder builder = new AggregatorFactories.Builder(); List metricParserList = new ArrayList<>(); for (NamedAggregator aggregator : aggregatorList) { + String functionName = aggregator.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT); + + // Skip count aggregations when optimization is enabled and it's count(*) or count(literal) + if (optimizeCount && "count".equals(functionName) && !aggregator.getDelegated().distinct() + && isCountStarOrLiteral(aggregator)) { + // Add a parser that will extract doc_count from bucket response + metricParserList.add(new DocCountParser(aggregator.getName())); + continue; + } + Pair pair = aggregator.accept(this, null); builder.addAggregator(pair.getLeft()); metricParserList.add(pair.getRight()); @@ -264,6 +286,15 @@ private Expression replaceStarOrLiteral(Expression countArg) { return countArg; } + /** + * Check if count aggregation is count(*) or count(literal) which can be optimized. + * Only these cases can use bucket doc_count instead of value_count. + */ + private boolean isCountStarOrLiteral(NamedAggregator aggregator) { + Expression arg = aggregator.getArguments().get(0); + return arg instanceof LiteralExpression; + } + /** * Make builder to build FilterAggregation for aggregations with filter in the bucket. * diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/DocCountParserTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/DocCountParserTest.java new file mode 100644 index 00000000000..422acd9bb9b --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/DocCountParserTest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.response; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.opensearch.response.agg.DocCountParser; + +class DocCountParserTest { + + @Test + void should_parse_doc_count_from_bucket() { + DocCountParser parser = new DocCountParser("count_field"); + Map bucket = new HashMap<>(); + bucket.put("doc_count", 42); + + Map result = parser.parseBucket(bucket); + assertEquals(42, result.get("count_field")); + } + + @Test + void should_return_zero_when_doc_count_missing() { + DocCountParser parser = new DocCountParser("count_field"); + Map bucket = new HashMap<>(); + + Map result = parser.parseBucket(bucket); + assertEquals(0, result.get("count_field")); + } + + @Test + void should_return_name() { + DocCountParser parser = new DocCountParser("test_name"); + assertEquals("test_name", parser.getName()); + } + + @Test + void should_throw_exception_for_aggregations_parse() { + DocCountParser parser = new DocCountParser("count_field"); + + assertThrows(UnsupportedOperationException.class, () -> parser.parse((org.opensearch.search.aggregations.Aggregation) null)); + } +} \ No newline at end of file diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java index 124028c7fc8..8e7460409b6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java @@ -587,6 +587,134 @@ void should_build_histogram_two_metrics() { false)); } + @Test + void should_optimize_count_with_date_histogram() { + // When count aggregation is used with date histogram (span), + // count should be optimized to use bucket doc_count instead of value_count + assertEquals( + format( + "{%n" + + " \"composite_buckets\" : {%n" + + " \"composite\" : {%n" + + " \"size\" : 1000,%n" + + " \"sources\" : [ {%n" + + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n" + + " \"date_histogram\" : {%n" + + " \"field\" : \"timestamp\",%n" + + " \"missing_bucket\" : true,%n" + + " \"missing_order\" : \"first\",%n" + + " \"order\" : \"asc\",%n" + + " \"fixed_interval\" : \"1h\"%n" + + " }%n" + + " }%n" + + " } ]%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named("count(*)", new CountAggregator(Arrays.asList(literal("*")), INTEGER))), + Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h"))))); + } + + @Test + void should_not_optimize_count_field_with_date_histogram() { + // When count(field_name) is used with date histogram, + // it should NOT be optimized because it needs to count non-null values of that field + assertEquals( + format( + "{%n" + + " \"composite_buckets\" : {%n" + + " \"composite\" : {%n" + + " \"size\" : 1000,%n" + + " \"sources\" : [ {%n" + + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n" + + " \"date_histogram\" : {%n" + + " \"field\" : \"timestamp\",%n" + + " \"missing_bucket\" : true,%n" + + " \"missing_order\" : \"first\",%n" + + " \"order\" : \"asc\",%n" + + " \"fixed_interval\" : \"1h\"%n" + + " }%n" + + " }%n" + + " } ]%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(age)\" : {%n" + + " \"value_count\" : {%n" + + " \"field\" : \"age\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named("count(age)", new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))), + Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h"))))); + } + + @Test + void should_optimize_count_with_date_histogram_bucket_nullable_false() { + // When count aggregation is used with date histogram (span) and bucket_nullable is false, + // count should be optimized to use bucket doc_count instead of value_count + assertEquals( + format( + "{%n" + + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n" + + " \"date_histogram\" : {%n" + + " \"field\" : \"timestamp\",%n" + + " \"fixed_interval\" : \"1h\",%n" + + " \"offset\" : 0,%n" + + " \"order\" : {%n" + + " \"_key\" : \"asc\"%n" + + " },%n" + + " \"keyed\" : false,%n" + + " \"min_doc_count\" : 0%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named("count(*)", new CountAggregator(Arrays.asList(literal("*")), INTEGER))), + Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h"))), + false)); + } + + @Test + void should_not_optimize_count_field_with_date_histogram_missing_bucket_false() { + // When count(field_name) is used with date histogram and missing_bucket is false, + // it should NOT be optimized because it needs to count non-null values of that field + assertEquals( + format( + "{%n" + + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n" + + " \"date_histogram\" : {%n" + + " \"field\" : \"timestamp\",%n" + + " \"fixed_interval\" : \"1h\",%n" + + " \"offset\" : 0,%n" + + " \"order\" : {%n" + + " \"_key\" : \"asc\"%n" + + " },%n" + + " \"keyed\" : false,%n" + + " \"min_doc_count\" : 0%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(age)\" : {%n" + + " \"value_count\" : {%n" + + " \"field\" : \"age\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named("count(age)", new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))), + Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h"))), + false)); + } + @Test void fixed_interval_time_span() { assertEquals( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 64ae7b187c2..99b3059f859 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -145,6 +145,37 @@ void should_build_count_other_literal_aggregation() { named("count(1)", new CountAggregator(Arrays.asList(literal(1)), INTEGER))))); } + @Test + void should_skip_count_aggregation_when_optimized() { + // When optimization is enabled, count aggregations should be skipped + // and DocCountParser should be used instead + assertEquals( + "{ }", + buildQueryWithOptimization( + Arrays.asList( + named("count(*)", new CountAggregator(Arrays.asList(literal("*")), INTEGER))), + true)); + } + + @Test + void should_not_skip_count_field_aggregation_when_optimized() { + // When optimization is enabled, count(field_name) should NOT be skipped + // because it needs to count non-null values of that specific field + assertEquals( + format( + "{%n" + + " \"count(age)\" : {%n" + + " \"value_count\" : {%n" + + " \"field\" : \"age\"%n" + + " }%n" + + " }%n" + + "}"), + buildQueryWithOptimization( + Arrays.asList( + named("count(age)", new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))), + true)); + } + @Test void should_build_min_aggregation() { assertEquals( @@ -514,4 +545,12 @@ private String buildQuery(List namedAggregatorList) { .readTree(aggregationBuilder.build(namedAggregatorList).getLeft().toString()) .toPrettyString(); } + + @SneakyThrows + private String buildQueryWithOptimization(List namedAggregatorList, boolean optimizeCount) { + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper + .readTree(aggregationBuilder.build(namedAggregatorList, optimizeCount).getLeft().toString()) + .toPrettyString(); + } }