Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME_MAIN;
import static org.opensearch.sql.calcite.utils.PlanUtils.ROW_NUMBER_COLUMN_NAME_SUBSEARCH;
import static org.opensearch.sql.calcite.utils.PlanUtils.getRelation;
import static org.opensearch.sql.calcite.utils.PlanUtils.getRexCall;
import static org.opensearch.sql.calcite.utils.PlanUtils.transformPlanToAttachChild;

import com.google.common.base.Strings;
Expand Down Expand Up @@ -53,6 +54,7 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -813,6 +815,23 @@ private void projectPlusOverriding(
context.relBuilder.rename(expectedRenameFields);
}

/**
 * Collects, per aggregate call, the input references of its first {@code count(FIELD)} call.
 *
 * <p>Each aggregate call is turned into a {@link RexNode} via {@code over().toRex()} and searched
 * with {@link #isCountField(RexCall)}. The returned list has one entry per aggregate call, in the
 * same order; an entry is empty when that aggregate call contains no {@code count(FIELD)}.
 *
 * @param aggCalls the resolved aggregate calls to inspect
 * @return one list of input references per aggregate call
 */
private List<List<RexInputRef>> extractInputRefList(List<RelBuilder.AggCall> aggCalls) {
  return aggCalls.stream()
      .map(aggCall -> getRexCall(aggCall.over().toRex(), this::isCountField))
      .map(
          countCalls ->
              countCalls.isEmpty()
                  ? List.<RexInputRef>of()
                  : PlanUtils.getInputRefs(countCalls.getFirst()))
      .toList();
}

/**
 * Tells whether the call has the shape {@code count(FIELD)}: a COUNT with exactly one operand
 * that is a direct input reference. {@code count()} and {@code count(EXPRESSION)} do not match.
 */
private boolean isCountField(RexCall call) {
  if (!call.isA(SqlKind.COUNT)) {
    return false;
  }
  List<RexNode> operands = call.getOperands();
  // Exactly one operand, and it must be a plain field reference rather than an expression.
  return operands.size() == 1 && operands.get(0) instanceof RexInputRef;
}

/**
* Resolve the aggregation with trimming unused fields to avoid bugs in {@link
* org.apache.calcite.sql2rel.RelDecorrelator#decorrelateRel(Aggregate, boolean)}
Expand All @@ -826,6 +845,72 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
List<UnresolvedExpression> groupExprList,
List<UnresolvedExpression> aggExprList,
CalcitePlanContext context) {
Pair<List<RexNode>, List<AggCall>> resolved =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[non-blocking] We should do this optimization in a separate rule instead of here; otherwise it affects the basic logical plan by adding a redundant filter. We can make this change as a follow-up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#4390 opened

resolveAttributesForAggregation(groupExprList, aggExprList, context);
List<RexNode> resolvedGroupByList = resolved.getLeft();
List<AggCall> resolvedAggCallList = resolved.getRight();

// The `doc_count` optimization requires a filter `isNotNull(RexInputRef)` for the
// `count(FIELD)` aggregation; it can only be applied to a single FIELD without grouping:
//
// Example 1: source=t | stats count(a)
// Before: Aggregate(count(a))
// \- Scan t
// After: Aggregate(count(a))
// \- Filter(isNotNull(a))
// \- Scan t
//
// Example 2: source=t | stats count(a), count(a)
// Before: Aggregate(count(a), count(a))
// \- Scan t
// After: Aggregate(count(a), count(a))
// \- Filter(isNotNull(a))
// \- Scan t
//
// Example 3: source=t | stats count(a) by b
// Before & After: Aggregate(count(a) by b)
// \- Scan t
//
// Example 4: source=t | stats count()
// Before & After: Aggregate(count())
// \- Scan t
//
// Example 5: source=t | stats count(), count(a)
// Before & After: Aggregate(count(), count(a))
// \- Scan t
//
// Example 6: source=t | stats count(a), count(b)
// Before & After: Aggregate(count(a), count(b))
// \- Scan t
//
// Example 7: source=t | stats count(a+1)
// Before & After: Aggregate(count(a+1))
// \- Scan t
if (resolvedGroupByList.isEmpty()) {
List<List<RexInputRef>> refsPerCount = extractInputRefList(resolvedAggCallList);
List<RexInputRef> distinctRefsOfCounts;
if (context.relBuilder.peek() instanceof org.apache.calcite.rel.core.Project project) {
List<RexNode> mappedInProject =
refsPerCount.stream()
.flatMap(List::stream)
.map(ref -> project.getProjects().get(ref.getIndex()))
.toList();
if (mappedInProject.stream().allMatch(RexInputRef.class::isInstance)) {
distinctRefsOfCounts =
mappedInProject.stream().map(RexInputRef.class::cast).distinct().toList();
} else {
distinctRefsOfCounts = List.of();
}
} else {
distinctRefsOfCounts = refsPerCount.stream().flatMap(List::stream).distinct().toList();
}
if (distinctRefsOfCounts.size() == 1 && refsPerCount.stream().noneMatch(List::isEmpty)) {
context.relBuilder.filter(context.relBuilder.isNotNull(distinctRefsOfCounts.getFirst()));
}
}

// Add project before aggregate:
//
// Example 1: source=t | where a > 1 | stats avg(b + 1) by c
// Before: Aggregate(avg(b + 1))
// \- Filter(a > 1)
Expand All @@ -836,23 +921,22 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
// \- Scan t
//
// Example 2: source=t | where a > 1 | top b by c
// Before: Aggregate(count)
// \-Filter(a > 1)
// Before: Aggregate(count(b) by c)
// \-Filter(a > 1 && isNotNull(b))
// \- Scan t
// After: Aggregate(count)
// After: Aggregate(count(b) by c)
// \- Project([c, b])
// \- Filter(a > 1)
// \- Filter(a > 1 && isNotNull(b))
// \- Scan t
// Example 3: source=t | stats count(): no project added for count()
// Before: Aggregate(count)
//
// Example 3: source=t | stats count(): no change for count()
// Before: Aggregate(count())
// \- Scan t
// After: Aggregate(count)
// After: Aggregate(count())
// \- Scan t
Pair<List<RexNode>, List<AggCall>> resolved =
resolveAttributesForAggregation(groupExprList, aggExprList, context);
List<RexInputRef> trimmedRefs = new ArrayList<>();
trimmedRefs.addAll(PlanUtils.getInputRefs(resolved.getLeft())); // group-by keys first
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolved.getRight()));
trimmedRefs.addAll(PlanUtils.getInputRefs(resolvedGroupByList)); // group-by keys first
trimmedRefs.addAll(PlanUtils.getInputRefsFromAggCall(resolvedAggCallList));
context.relBuilder.project(trimmedRefs);

// Re-resolve all attributes based on adding trimmed Project.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptTable;
Expand Down Expand Up @@ -255,6 +256,9 @@ static RelBuilder.AggCall makeAggCall(

/** Get all uniq input references from a RexNode. */
static List<RexInputRef> getInputRefs(RexNode node) {
if (node == null) {
return List.of();
}
List<RexInputRef> inputRefs = new ArrayList<>();
node.accept(
new RexVisitorImpl<Void>(true) {
Expand All @@ -274,6 +278,26 @@ static List<RexInputRef> getInputRefs(List<RexNode> nodes) {
return nodes.stream().flatMap(node -> getInputRefs(node).stream()).toList();
}

/**
 * Collects the distinct {@link RexCall}s reachable from a node that satisfy a predicate.
 *
 * <p>A call that satisfies the predicate is recorded (once) and its operands are not visited;
 * a call that does not satisfy it has each of its operands visited in turn.
 *
 * @param node the expression to search
 * @param predicate the condition a call must satisfy to be collected
 * @return the matching calls without duplicates, in the order they were first encountered
 */
static List<RexCall> getRexCall(RexNode node, Predicate<RexCall> predicate) {
  List<RexCall> matches = new ArrayList<>();
  node.accept(
      new RexVisitorImpl<Void>(true) {
        @Override
        public Void visitCall(RexCall inputCall) {
          if (!predicate.test(inputCall)) {
            // Not a match: keep looking inside its operands.
            for (RexNode operand : inputCall.getOperands()) {
              operand.accept(this);
            }
            return null;
          }
          if (!matches.contains(inputCall)) {
            matches.add(inputCall);
          }
          return null;
        }
      });
  return matches;
}

/** Get all uniq input references from a list of agg calls. */
static List<RexInputRef> getInputRefsFromAggCall(List<RelBuilder.AggCall> aggCalls) {
return aggCalls.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.calcite.remote;

import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_LOGS;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NESTED_SIMPLE;
Expand Down Expand Up @@ -644,6 +645,143 @@ public void testExplainMinOnStringField() throws IOException {
explainQueryToString("source=opensearch-sql_test_index_account | stats min(firstname)"));
}

/**
 * Checks the explain output of {@code stats count(...)} without grouping against the expected
 * plan files {@code explain_count_agg_push1.yaml} .. {@code explain_count_agg_push10.yaml}.
 */
@Test
@Override
public void testCountAggPushDownExplain() throws IOException {
  enabledOnlyWhenPushdownIsEnabled();

  // should be optimized by hits.total.value
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push1.yaml"),
      explainQueryToString("source=opensearch-sql_test_index_account | stats count() as cnt"));

  // should be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push2.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | stats count(lastname) as cnt"));

  // should be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push3.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | eval name = lastname | stats count(name) as"
              + " cnt"));

  // should be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push4.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | stats count() as c1, count() as c2"));

  // should be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push5.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | stats count(lastname) as c1,"
              + " count(lastname) as c2"));

  // should be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push6.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | eval name = lastname | stats"
              + " count(lastname), count(name)"));

  // should not be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push7.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | stats count(balance + 1) as cnt"));

  // should not be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push8.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | stats count() as c1, count(lastname) as"
              + " c2"));

  // should not be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push9.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | stats count(firstname), count(lastname)"));

  // should not be optimized
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_count_agg_push10.yaml"),
      explainQueryToString(
          "source=opensearch-sql_test_index_account | eval name = lastname | stats"
              + " count(firstname), count(name)"));
}

/**
 * Checks the explain output of {@code stats count(...) by gender} against the expected plan
 * files {@code explain_agg_counts_by1.yaml} .. {@code explain_agg_counts_by6.yaml}.
 */
@Test
public void testExplainCountsByAgg() throws IOException {
  enabledOnlyWhenPushdownIsEnabled();

  // case of only count(): doc_count works
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_agg_counts_by1.yaml"),
      explainQueryToString(
          String.format(
              "source=%s | stats count(), count() as c1 by gender", TEST_INDEX_ACCOUNT)));

  // count(FIELD) by: doc_count doesn't work
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_agg_counts_by2.yaml"),
      explainQueryToString(
          String.format(
              "source=%s | stats count(balance) as c1, count(balance) as c2 by gender",
              TEST_INDEX_ACCOUNT)));

  // count(FIELD) by, with an aliased field: doc_count doesn't work
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_agg_counts_by3.yaml"),
      explainQueryToString(
          String.format(
              "source=%s | eval account_number_alias = account_number"
                  + " | stats count(account_number), count(account_number_alias) as c2 by gender",
              TEST_INDEX_ACCOUNT)));

  // count() + count(FIELD) by: doc_count doesn't work
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_agg_counts_by4.yaml"),
      explainQueryToString(
          String.format(
              "source=%s | stats count(), count(account_number) by gender", TEST_INDEX_ACCOUNT)));

  // count(FIELD1) + count(FIELD2) by: doc_count doesn't work
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_agg_counts_by5.yaml"),
      explainQueryToString(
          String.format(
              "source=%s | stats count(balance), count(account_number) by gender",
              TEST_INDEX_ACCOUNT)));

  // case of count(EXPRESSION) by: doc_count doesn't work
  assertYamlEqualsJsonIgnoreId(
      loadExpectedPlan("explain_agg_counts_by6.yaml"),
      explainQueryToString(
          String.format(
              "source=%s | eval b_1 = balance + 1"
                  + " | stats count(b_1), count(pow(balance, 2)) as c3 by gender",
              TEST_INDEX_ACCOUNT)));
}

@Test
public void testExplainSortOnMetricsNoBucketNullable() throws IOException {
// TODO enhancement later: https://github.com/opensearch-project/sql/issues/4282
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -772,4 +772,29 @@ public void testStatsBySpanTimeWithNullBucket() throws IOException {
rows(8213, "2025-07-31 00:00:00"),
rows(8490, "2025-07-31 12:00:00"));
}

/**
 * Runs several flavours of {@code count} (no argument, field, expression, eval'd field) grouped
 * by gender and verifies that every count column reports the same per-gender totals.
 */
@Test
public void testStatsByCounts() throws IOException {
  String query =
      String.format(
          "source=%s | eval b_1 = balance + 1 | stats count(), count() as c1,"
              + " count(account_number), count(lastname) as c2, count(balance/10),"
              + " count(pow(balance, 2)) as c3, count(b_1) by gender",
          TEST_INDEX_ACCOUNT);
  JSONObject response = executeQuery(query);

  // All count columns share one type: bigint when Calcite is enabled, int otherwise.
  String countType = isCalciteEnabled() ? "bigint" : "int";
  verifySchema(
      response,
      schema("count()", null, countType),
      schema("c1", null, countType),
      schema("count(account_number)", null, countType),
      schema("c2", null, countType),
      schema("count(balance/10)", null, countType),
      schema("c3", null, countType),
      schema("count(b_1)", null, countType),
      schema("gender", null, "string"));
  verifyDataRows(
      response,
      rows(493, 493, 493, 493, 493, 493, 493, "F"),
      rows(507, 507, 507, 507, 507, 507, 507, "M"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalProject(count()=[$1], c1=[$1], gender=[$0])
LogicalAggregate(group=[{0}], count()=[COUNT()])
LogicalProject(gender=[$4])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
physical: |
EnumerableCalc(expr#0..1=[{inputs}], count()=[$t0], count()0=[$t0], gender=[$t1])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},count()=COUNT()), LIMIT->10000, PROJECT->[count(), gender]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":true,"missing_order":"first","order":"asc"}}}]}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalProject(c1=[$1], c2=[$1], gender=[$0])
LogicalAggregate(group=[{0}], c1=[COUNT($1)])
LogicalProject(gender=[$4], balance=[$3])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
physical: |
EnumerableCalc(expr#0..1=[{inputs}], c1=[$t0], c10=[$t0], gender=[$t1])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},c1=COUNT($1)), LIMIT->10000, PROJECT->[c1, gender]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":true,"missing_order":"first","order":"asc"}}}]},"aggregations":{"c1":{"value_count":{"field":"balance"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
Loading
Loading