diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 44ee4d90386..1c133cb99cd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -560,6 +560,11 @@ private Pair, List> aggregateWithTrimming( // \- Project([c, b]) // \- Filter(a > 1) // \- Scan t + // Example 3: source=t | stats count(): no project added for count() + // Before: Aggregate(count) + // \- Scan t + // After: Aggregate(count) + // \- Scan t Pair, List> resolved = resolveAttributesForAggregation(groupExprList, aggExprList, context); List trimmedRefs = new ArrayList<>(); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 1015a12a0a9..509ab22132f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -97,6 +97,14 @@ public void testFilterAndAggPushDownExplain() throws IOException { + "| stats avg(age) AS avg_age by state, city")); } + @Test + public void testCountAggPushDownExplain() throws IOException { + String expected = loadExpectedPlan("explain_count_agg_push.json"); + assertJsonEqualsIgnoreId( + expected, + explainQueryToString("source=opensearch-sql_test_index_account | stats count() as cnt")); + } + @Test public void testSortPushDownExplain() throws IOException { String expected = loadExpectedPlan("explain_sort_push.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_count_agg_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_count_agg_push.json new file mode 100644 index 00000000000..be28c90d0cf --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_count_agg_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalAggregate(group=[{}], cnt=[COUNT()])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={},cnt=COUNT())], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_agg_push.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_agg_push.json new file mode 100644 index 00000000000..18311e50af4 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_agg_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalAggregate(group=[{}], cnt=[COUNT()])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableAggregate(group=[{}], cnt=[COUNT()])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_count_agg_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_count_agg_push.json new file mode 100644 index 00000000000..0d302725e1b --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_count_agg_push.json @@ -0,0 +1,15 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[cnt]" + }, + "children": [{ + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"cnt\":{\"value_count\":{\"field\":\"_index\"}}}}, needClean=true, searchDone=false, pitId=*, cursorKeepAlive=null, searchAfter=null, searchResponse=null)" + }, + "children": [] + }] + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java index d54c924db47..35c171ab0f7 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java @@ -9,6 +9,7 @@ import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.sql.SqlKind; import org.immutables.value.Value; import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan; @@ -29,6 +30,11 @@ public void onMatch(RelOptRuleCall call) { final LogicalProject project = call.rel(1); final CalciteLogicalIndexScan scan = call.rel(2); apply(call, aggregate, project, scan); + } else if (call.rels.length == 2) { + // case of count() without group-by + final LogicalAggregate aggregate = call.rel(0); + final CalciteLogicalIndexScan scan = call.rel(1); + apply(call, aggregate, null, scan); } else { throw new AssertionError( String.format( @@ -54,6 +60,7 @@ public interface Config extends RelRule.Config { Config DEFAULT = ImmutableOpenSearchAggregateIndexScanRule.Config.builder() .build() + .withDescription("Agg-Project-TableScan") .withOperandSupplier( b0 -> b0.operand(LogicalAggregate.class) @@ -71,6 +78,28 @@ public interface Config extends RelRule.Config { OpenSearchIndexScanRule ::noAggregatePushed)) .noInputs()))); + Config COUNT_STAR = + ImmutableOpenSearchAggregateIndexScanRule.Config.builder() + .build() + .withDescription("Agg[count()]-TableScan") + .withOperandSupplier( + b0 -> + b0.operand(LogicalAggregate.class) + .predicate( + agg -> + agg.getGroupSet().isEmpty() + && agg.getAggCallList().stream() + .allMatch( + call -> + call.getAggregation().kind == SqlKind.COUNT + && call.getArgList().isEmpty())) + .oneInput( + b1 -> + b1.operand(CalciteLogicalIndexScan.class) + .predicate( + Predicate.not(OpenSearchIndexScanRule::isLimitPushed) + .and(OpenSearchIndexScanRule::noAggregatePushed)) + .noInputs())); @Override default OpenSearchAggregateIndexScanRule toRule() { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java index e67c041383d..05f3f6cc4f9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchIndexRules.java @@ -16,6 +16,8 @@ public class OpenSearchIndexRules { OpenSearchFilterIndexScanRule.Config.DEFAULT.toRule(); private static final OpenSearchAggregateIndexScanRule AGGREGATE_INDEX_SCAN = OpenSearchAggregateIndexScanRule.Config.DEFAULT.toRule(); + private static final OpenSearchAggregateIndexScanRule COUNT_STAR_INDEX_SCAN = + OpenSearchAggregateIndexScanRule.Config.COUNT_STAR.toRule(); private static final OpenSearchLimitIndexScanRule LIMIT_INDEX_SCAN = OpenSearchLimitIndexScanRule.Config.DEFAULT.toRule(); private static final OpenSearchSortIndexScanRule SORT_INDEX_SCAN = @@ -26,6 +28,7 @@ public class OpenSearchIndexRules { PROJECT_INDEX_SCAN, FILTER_INDEX_SCAN, AGGREGATE_INDEX_SCAN, + COUNT_STAR_INDEX_SCAN, LIMIT_INDEX_SCAN, SORT_INDEX_SCAN);