Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,20 @@ public List<Map<String, Object>> parse(Aggregations aggregations) {
/**
 * Parse a single composite-aggregation bucket into a flat result row.
 *
 * <p>The row contains the bucket key columns, any regular metric values, and — for count
 * aggregations that were elided at request-build time — the bucket's own doc_count.
 *
 * @param bucket composite aggregation bucket from the search response
 * @return map of column name to value for this bucket
 */
private Map<String, Object> parse(CompositeAggregation.Bucket bucket) {
  Map<String, Object> resultMap = new HashMap<>();
  resultMap.putAll(bucket.getKey());

  // Parse regular metric aggregations present in the bucket's response.
  resultMap.putAll(metricsParser.parse(bucket.getAggregations()));

  // DocCountParser entries have no backing sub-aggregation in the response (the count
  // aggregation was skipped when building the request), so feed them the bucket's doc_count.
  // Build the one-entry context map once — it is identical for every parser in the loop.
  Map<String, Object> docCountContext = new HashMap<>();
  docCountContext.put("doc_count", bucket.getDocCount());
  for (MetricParser parser : metricsParser.getMetricParserList()) {
    if (parser instanceof DocCountParser) {
      resultMap.putAll(((DocCountParser) parser).parseBucket(docCountContext));
    }
  }

  return resultMap;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.opensearch.response.agg;

import java.util.HashMap;
import java.util.Map;
import org.opensearch.search.aggregations.Aggregation;

/**
 * Parser for extracting doc_count from bucket aggregations to optimize count() functions.
 *
 * <p>Instances of this parser have no backing {@link Aggregation} in the search response — the
 * count aggregation was elided from the request — so {@link #parse(Aggregation)} is unsupported
 * and callers must invoke {@link #parseBucket(Map)} with the bucket context instead.
 */
public class DocCountParser implements MetricParser {

  /** Key under which the bucket document count is supplied to {@link #parseBucket(Map)}. */
  private static final String DOC_COUNT_KEY = "doc_count";

  /** Result column name this parser produces (the named count aggregation's name). */
  private final String name;

  public DocCountParser(String name) {
    this.name = name;
  }

  /**
   * Unsupported: this parser is driven by the bucket itself, not by a named sub-aggregation.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public Map<String, Object> parse(Aggregation aggregation) {
    throw new UnsupportedOperationException(
        "DocCountParser should be used with bucket context, not aggregations");
  }

  /**
   * Parse doc_count from bucket map.
   *
   * <p>A missing or non-numeric {@code doc_count} entry yields 0. NOTE(review): the value is
   * narrowed via {@code Number.intValue()}, so buckets holding more than Integer.MAX_VALUE
   * documents would silently truncate — confirm INTEGER is the intended result type.
   *
   * @param bucket bucket map containing doc_count
   * @return Map with the count value keyed by this parser's name
   */
  public Map<String, Object> parseBucket(Map<String, Object> bucket) {
    Object docCount = bucket.get(DOC_COUNT_KEY);
    int count = (docCount instanceof Number) ? ((Number) docCount).intValue() : 0;
    Map<String, Object> result = new HashMap<>();
    // Store the plain boxed Integer directly; wrapping in ExprIntegerValue and immediately
    // unwrapping with value() was a no-op round-trip.
    result.put(name, count);
    return result;
  }

  @Override
  public String getName() {
    return name;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ public class MetricParserHelper {

private final Map<String, MetricParser> metricParserMap;

private final List<MetricParser> metricParserList;

/**
 * Construct the helper from the given metric parsers.
 *
 * <p>The list reference itself is retained (not copied) so callers can iterate the parsers
 * later, e.g. to find DocCountParser instances that have no backing aggregation in the
 * response. A by-name map is also built for lookup while parsing; per the
 * {@code Collectors.toMap} contract, two parsers sharing a name raise IllegalStateException.
 */
public MetricParserHelper(List<MetricParser> metricParserList) {
this.metricParserList = metricParserList;
metricParserMap =
metricParserList.stream().collect(Collectors.toMap(MetricParser::getName, m -> m));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ public AggregationQueryBuilder(ExpressionSerializer serializer) {
List<Pair<Sort.SortOption, Expression>> sortList,
boolean bucketNullable) {

// Check if we can optimize count aggregations with bucket doc_count
boolean optimizeCount = canOptimizeCount(namedAggregatorList, groupByList);

final Pair<AggregatorFactories.Builder, List<MetricParser>> metrics =
metricBuilder.build(namedAggregatorList);
metricBuilder.build(namedAggregatorList, optimizeCount);

if (groupByList.isEmpty()) {
// no bucket
Expand Down Expand Up @@ -111,6 +114,24 @@ public AggregationQueryBuilder(ExpressionSerializer serializer) {
}
}

/**
 * Decide whether count() aggregations can be answered from bucket doc_count instead of a
 * value_count sub-aggregation. The rewrite applies only when a span expression (which produces
 * date-histogram buckets) is grouped on AND at least one non-distinct count aggregation exists.
 */
private boolean canOptimizeCount(
    List<NamedAggregator> namedAggregatorList, List<NamedExpression> groupByList) {
  // Without a span (date histogram) bucket source there is no doc_count to reuse.
  boolean spanPresent =
      groupByList.stream()
          .anyMatch(
              group ->
                  group.getDelegated()
                      instanceof org.opensearch.sql.expression.span.SpanExpression);
  if (!spanPresent) {
    return false;
  }

  // A plain (non-distinct) count() must be present for the optimization to matter.
  return namedAggregatorList.stream()
      .anyMatch(
          aggregator ->
              "count".equalsIgnoreCase(aggregator.getFunctionName().getFunctionName())
                  && !aggregator.getDelegated().distinct());
}

/** Build mapping for OpenSearchExprValueFactory. */
public Map<String, OpenSearchDataType> buildTypeMapping(
List<NamedAggregator> namedAggregatorList, List<NamedExpression> groupByList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,31 @@ public MetricAggregationBuilder(ExpressionSerializer serializer) {
*/
public Pair<AggregatorFactories.Builder, List<MetricParser>> build(
List<NamedAggregator> aggregatorList) {
// Delegate with optimizeCount=false: plain builds never elide count aggregations.
return build(aggregatorList, false);
}

/**
* Build AggregatorFactories.Builder from {@link NamedAggregator} with optimization context.
*
* @param aggregatorList aggregator list
* @param optimizeCount whether to optimize count aggregations by using bucket doc_count
* @return AggregatorFactories.Builder
*/
public Pair<AggregatorFactories.Builder, List<MetricParser>> build(
List<NamedAggregator> aggregatorList, boolean optimizeCount) {
AggregatorFactories.Builder builder = new AggregatorFactories.Builder();
List<MetricParser> metricParserList = new ArrayList<>();
for (NamedAggregator aggregator : aggregatorList) {
String functionName = aggregator.getFunctionName().getFunctionName().toLowerCase(Locale.ROOT);

// Skip count aggregations when optimization is enabled and it's count(*) or count(literal)
if (optimizeCount && "count".equals(functionName) && !aggregator.getDelegated().distinct()
&& isCountStarOrLiteral(aggregator)) {
// Add a parser that will extract doc_count from bucket response
metricParserList.add(new DocCountParser(aggregator.getName()));
continue;
}

Pair<AggregationBuilder, MetricParser> pair = aggregator.accept(this, null);
builder.addAggregator(pair.getLeft());
metricParserList.add(pair.getRight());
Expand Down Expand Up @@ -264,6 +286,15 @@ private Expression replaceStarOrLiteral(Expression countArg) {
return countArg;
}

/**
 * Check if count aggregation is count(*) or count(literal) which can be optimized. Only these
 * cases can use bucket doc_count instead of value_count — count(field) must remain a
 * value_count because it counts only non-null values of that field.
 *
 * @param aggregator the count aggregator to inspect
 * @return true if the aggregation's single argument is a literal expression
 */
private boolean isCountStarOrLiteral(NamedAggregator aggregator) {
  List<Expression> arguments = aggregator.getArguments();
  // Guard against an empty argument list: get(0) would throw IndexOutOfBoundsException;
  // an unclassifiable count is simply not optimized.
  return !arguments.isEmpty() && arguments.get(0) instanceof LiteralExpression;
}

/**
* Make builder to build FilterAggregation for aggregations with filter in the bucket.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.opensearch.response;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.sql.opensearch.response.agg.DocCountParser;

/** Unit tests for {@link DocCountParser}. */
class DocCountParserTest {

  @Test
  void should_parse_doc_count_from_bucket() {
    DocCountParser parser = new DocCountParser("count_field");
    Map<String, Object> bucket = new HashMap<>();
    bucket.put("doc_count", 42);

    Map<String, Object> result = parser.parseBucket(bucket);
    assertEquals(42, result.get("count_field"));
  }

  @Test
  void should_return_zero_when_doc_count_missing() {
    DocCountParser parser = new DocCountParser("count_field");
    Map<String, Object> bucket = new HashMap<>();

    Map<String, Object> result = parser.parseBucket(bucket);
    assertEquals(0, result.get("count_field"));
  }

  @Test
  void should_return_zero_when_doc_count_not_a_number() {
    // A non-numeric doc_count is treated the same as a missing entry.
    DocCountParser parser = new DocCountParser("count_field");
    Map<String, Object> bucket = new HashMap<>();
    bucket.put("doc_count", "not-a-number");

    Map<String, Object> result = parser.parseBucket(bucket);
    assertEquals(0, result.get("count_field"));
  }

  @Test
  void should_return_name() {
    DocCountParser parser = new DocCountParser("test_name");
    assertEquals("test_name", parser.getName());
  }

  @Test
  void should_throw_exception_for_aggregations_parse() {
    DocCountParser parser = new DocCountParser("count_field");

    // parse(Aggregation) is unsupported; only parseBucket carries the bucket context.
    assertThrows(UnsupportedOperationException.class, () -> parser.parse((Aggregation) null));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,134 @@ void should_build_histogram_two_metrics() {
false));
}

@Test
void should_optimize_count_with_date_histogram() {
  // When a count(*) aggregation is combined with a date histogram (span),
  // the count is optimized away: it will be served from bucket doc_count, so
  // the expected request JSON contains no "aggregations" section at all.
  assertEquals(
      format(
          "{%n"
              + " \"composite_buckets\" : {%n"
              + " \"composite\" : {%n"
              + " \"size\" : 1000,%n"
              + " \"sources\" : [ {%n"
              + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n"
              + " \"date_histogram\" : {%n"
              + " \"field\" : \"timestamp\",%n"
              + " \"missing_bucket\" : true,%n"
              + " \"missing_order\" : \"first\",%n"
              + " \"order\" : \"asc\",%n"
              + " \"fixed_interval\" : \"1h\"%n"
              + " }%n"
              + " }%n"
              + " } ]%n"
              + " }%n"
              + " }%n"
              + "}"),
      buildQuery(
          Arrays.asList(
              named("count(*)", new CountAggregator(Arrays.asList(literal("*")), INTEGER))),
          Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h")))));
}

@Test
void should_not_optimize_count_field_with_date_histogram() {
  // When count(field_name) is used with a date histogram, it must NOT be optimized:
  // doc_count counts every document, but count(age) counts only non-null values of
  // that field, so a value_count sub-aggregation is still expected in the request.
  assertEquals(
      format(
          "{%n"
              + " \"composite_buckets\" : {%n"
              + " \"composite\" : {%n"
              + " \"size\" : 1000,%n"
              + " \"sources\" : [ {%n"
              + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n"
              + " \"date_histogram\" : {%n"
              + " \"field\" : \"timestamp\",%n"
              + " \"missing_bucket\" : true,%n"
              + " \"missing_order\" : \"first\",%n"
              + " \"order\" : \"asc\",%n"
              + " \"fixed_interval\" : \"1h\"%n"
              + " }%n"
              + " }%n"
              + " } ]%n"
              + " },%n"
              + " \"aggregations\" : {%n"
              + " \"count(age)\" : {%n"
              + " \"value_count\" : {%n"
              + " \"field\" : \"age\"%n"
              + " }%n"
              + " }%n"
              + " }%n"
              + " }%n"
              + "}"),
      buildQuery(
          Arrays.asList(
              named("count(age)", new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))),
          Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h")))));
}

@Test
void should_optimize_count_with_date_histogram_bucket_nullable_false() {
  // Same optimization with bucketNullable=false: a plain date_histogram (not a
  // composite source) is emitted, and count(*) is still served from bucket
  // doc_count, so no metric sub-aggregation appears in the expected JSON.
  assertEquals(
      format(
          "{%n"
              + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n"
              + " \"date_histogram\" : {%n"
              + " \"field\" : \"timestamp\",%n"
              + " \"fixed_interval\" : \"1h\",%n"
              + " \"offset\" : 0,%n"
              + " \"order\" : {%n"
              + " \"_key\" : \"asc\"%n"
              + " },%n"
              + " \"keyed\" : false,%n"
              + " \"min_doc_count\" : 0%n"
              + " }%n"
              + " }%n"
              + "}"),
      buildQuery(
          Arrays.asList(
              named("count(*)", new CountAggregator(Arrays.asList(literal("*")), INTEGER))),
          Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h"))),
          false));
}

@Test
void should_not_optimize_count_field_with_date_histogram_missing_bucket_false() {
  // count(field_name) with a date histogram must NOT be optimized even when the
  // builder is invoked with bucketNullable=false: value_count is still required to
  // count only non-null values of the field.
  // NOTE(review): the method name says "missing_bucket_false" but the argument
  // passed below is the bucketNullable flag — confirm whether the name should be
  // "..._bucket_nullable_false" to match the sibling test above.
  assertEquals(
      format(
          "{%n"
              + " \"SpanExpression(field=timestamp, value=1, unit=H)\" : {%n"
              + " \"date_histogram\" : {%n"
              + " \"field\" : \"timestamp\",%n"
              + " \"fixed_interval\" : \"1h\",%n"
              + " \"offset\" : 0,%n"
              + " \"order\" : {%n"
              + " \"_key\" : \"asc\"%n"
              + " },%n"
              + " \"keyed\" : false,%n"
              + " \"min_doc_count\" : 0%n"
              + " },%n"
              + " \"aggregations\" : {%n"
              + " \"count(age)\" : {%n"
              + " \"value_count\" : {%n"
              + " \"field\" : \"age\"%n"
              + " }%n"
              + " }%n"
              + " }%n"
              + " }%n"
              + "}"),
      buildQuery(
          Arrays.asList(
              named("count(age)", new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))),
          Arrays.asList(named(span(ref("timestamp", TIMESTAMP), literal(1), "h"))),
          false));
}

@Test
void fixed_interval_time_span() {
assertEquals(
Expand Down
Loading
Loading