diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index cfc270fb6b8..238757626d4 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -80,25 +80,37 @@ COUNT Description >>>>>>>>>>> -Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. The ``C()`` function can be used as an abbreviation for ``COUNT()``. +Usage: Returns a count of the number of expr in the rows retrieved. The ``C()`` function can be used as an abbreviation for ``COUNT()``. To perform filtered counting, wrap the condition to be satisfied in an ``eval`` expression. Example:: - os> source=accounts | stats count(); + os> source=accounts | stats count(), c(); fetched rows / total rows = 1/1 - +---------+ - | count() | - |---------| - | 4 | - +---------+ + +---------+-----+ + | count() | c() | + |---------+-----| + | 4 | 4 | + +---------+-----+ + +Example of filtered counting:: + + os> source=accounts | stats count(eval(age > 30)) as mature_users; + fetched rows / total rows = 1/1 + +--------------+ + | mature_users | + |--------------| + | 3 | + +--------------+ + +Example of filtered counting with complex conditions:: - os> source=accounts | stats c(); + os> source=accounts | stats count(eval(age > 30 and balance > 25000)) as high_value_users; fetched rows / total rows = 1/1 - +-----+ - | c() | - |-----| - | 4 | - +-----+ + +------------------+ + | high_value_users | + |------------------| + | 1 | + +------------------+ SUM --- diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java index 9d23ff4b542..b18fd5a7e6d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java @@ -816,6 +816,72 @@ public void testAggByByteNumberWithScript() throws IOException { 
verifyDataRows(response, rows(1, 4)); } + @Test + public void testCountEvalSimpleCondition() throws IOException { + JSONObject actual = + executeQuery( + String.format("source=%s | stats count(eval(age > 30)) as c", TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(6)); + } + + @Test + public void testCountEvalComplexCondition() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats count(eval(balance > 20000 and age < 35)) as c", + TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(3)); + } + + @Test + public void testCountEvalGroupBy() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats count(eval(balance > 25000)) as high_balance by gender", + TEST_INDEX_BANK)); + verifySchema(actual, schema("gender", "string"), schema("high_balance", "bigint")); + verifyDataRows(actual, rows(3, "F"), rows(1, "M")); + } + + @Test + public void testCountEvalWithMultipleAggregations() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats count(eval(age > 30)) as mature_count, " + + "count(eval(balance > 25000)) as high_balance_count, " + + "count() as total_count", + TEST_INDEX_BANK)); + verifySchema( + actual, + schema("mature_count", "bigint"), + schema("high_balance_count", "bigint"), + schema("total_count", "bigint")); + verifyDataRows(actual, rows(6, 4, 7)); + } + + @Test + public void testShortcutCEvalSimpleCondition() throws IOException { + JSONObject actual = + executeQuery(String.format("source=%s | stats c(eval(age > 30)) as c", TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(6)); + } + + @Test + public void testShortcutCEvalComplexCondition() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats c(eval(balance > 20000 and age < 35)) as c", TEST_INDEX_BANK)); + 
verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(3)); + } + @Test public void testPercentileShortcuts() throws IOException { JSONObject actual = diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index c52b7fd1e98..06741c62a85 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -429,6 +429,7 @@ statsAggTerm // aggregation functions statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall + | (COUNT | C) LT_PRTHS evalExpression RT_PRTHS # countEvalFunctionCall | (COUNT | C) LT_PRTHS RT_PRTHS # countAllFunctionCall | PERCENTILE_SHORTCUT LT_PRTHS valueExpression RT_PRTHS # percentileShortcutFunctionCall | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall @@ -501,6 +502,10 @@ valueExpression | LT_PRTHS logicalExpression RT_PRTHS # nestedValueExpr ; +evalExpression + : EVAL LT_PRTHS logicalExpression RT_PRTHS + ; + functionCall : evalFunctionCall | dataTypeFunctionCall diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index f909b57164f..851bba30615 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -6,6 +6,39 @@ package org.opensearch.sql.ppl.parser; import static org.opensearch.sql.expression.function.BuiltinFunctionName.*; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; +import static 
org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountEvalFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DoubleLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FloatLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsQualifiedNameContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsTableQualifiedNameContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsWildcardQualifiedNameContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; 
+import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.MultiFieldRelevanceFunctionContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SingleFieldRelevanceFunctionContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -41,6 +74,7 @@ import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DoubleLiteralContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalExpressionContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FloatLiteralContext; @@ -229,12 +263,28 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex return new AggregateFunction("count", AllFields.of()); } + @Override + public UnresolvedExpression visitCountEvalFunctionCall(CountEvalFunctionCallContext ctx) { + return new AggregateFunction("count", visit(ctx.evalExpression())); + } + @Override public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { String funcName = ctx.DISTINCT_COUNT_APPROX() != null 
? "distinct_count_approx" : "count"; return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } + @Override + public UnresolvedExpression visitEvalExpression(EvalExpressionContext ctx) { + /* + * Rewrite "eval(p)" as "CASE WHEN p THEN 1 ELSE NULL END" so that COUNT or DISTINCT_COUNT + * can correctly perform filtered counting. + * Note: at present only eval() inside counting functions is supported. + */ + UnresolvedExpression predicate = visit(ctx.logicalExpression()); + return AstDSL.caseWhen(null, AstDSL.when(predicate, AstDSL.intLiteral(1))); + } + @Override public UnresolvedExpression visitPercentileApproxFunctionCall( OpenSearchPPLParser.PercentileApproxFunctionCallContext ctx) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java index 2e70a210d6e..f40e90ff0e6 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java @@ -100,6 +100,37 @@ private Node plan(PPLSyntaxParser parser, String query) { return builder.visit(parser.parse(query)); } + /** + * Fluent API for building count(eval) test cases. Provides a clean and readable way to define PPL + * queries and their expected outcomes. 
+ */ + protected PPLQueryTestBuilder withPPLQuery(String ppl) { + return new PPLQueryTestBuilder(ppl); + } + + protected class PPLQueryTestBuilder { + private final RelNode relNode; + + public PPLQueryTestBuilder(String ppl) { + this.relNode = getRelNode(ppl); + } + + public PPLQueryTestBuilder expectLogical(String expectedLogical) { + verifyLogical(relNode, expectedLogical); + return this; + } + + public PPLQueryTestBuilder expectResult(String expectedResult) { + verifyResult(relNode, expectedResult); + return this; + } + + public PPLQueryTestBuilder expectSparkSQL(String expectedSparkSql) { + verifyPPLToSparkSQL(relNode, expectedSparkSql); + return this; + } + } + /** Verify the logical plan of the given RelNode */ public void verifyLogical(RelNode rel, String expectedLogical) { assertThat(rel, hasTree(expectedLogical)); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java new file mode 100644 index 00000000000..91d9a8de09d --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +/** + * Unit tests for count(eval) functionality in CalcitePPL engine. Tests various scenarios of + * filtered count aggregations. 
+ */ +public class CalcitePPLCountEvalTest extends CalcitePPLAbstractTest { + + public CalcitePPLCountEvalTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testCountEvalSimpleCondition() { + withPPLQuery("source=EMP | stats count(eval(SAL > 2000)) as c") + .expectLogical( + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("c=6\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalComplexCondition() { + withPPLQuery("source=EMP | stats count(eval(SAL > 2000 and DEPTNO < 30)) as c") + .expectLogical( + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f2=[CASE(AND(>($5, 2000), <($7, 30)), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("c=5\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `SAL` > 2000 AND `DEPTNO` < 30 THEN 1 ELSE NULL END) `c`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalStringComparison() { + withPPLQuery("source=EMP | stats count(eval(JOB = 'MANAGER')) as manager_count") + .expectLogical( + "LogicalAggregate(group=[{}], manager_count=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(=($2, 'MANAGER':VARCHAR), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("manager_count=3\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `JOB` = 'MANAGER' THEN 1 ELSE NULL END) `manager_count`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalArithmeticExpression() { + withPPLQuery("source=EMP | stats count(eval(SAL / COMM > 10)) as high_ratio") + .expectLogical( + "LogicalAggregate(group=[{}], high_ratio=[COUNT($0)])\n" + + " LogicalProject($f2=[CASE(>(DIVIDE($5, $6), 10), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + 
.expectResult("high_ratio=0\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `DIVIDE`(`SAL`, `COMM`) > 10 THEN 1 ELSE NULL END)" + + " `high_ratio`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalWithNullHandling() { + withPPLQuery("source=EMP | stats count(eval(isnotnull(MGR))) as non_null_mgr") + .expectLogical( + "LogicalAggregate(group=[{}], non_null_mgr=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(IS NOT NULL($3), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("non_null_mgr=13\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `MGR` IS NOT NULL THEN 1 ELSE NULL END) `non_null_mgr`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testShortcutCEvalSimpleCondition() { + withPPLQuery("source=EMP | stats c(eval(SAL > 2000)) as c") + .expectLogical( + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("c=6\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`"); + } + + @Test + public void testShortcutCEvalComplexCondition() { + withPPLQuery("source=EMP | stats c(eval(JOB = 'MANAGER')) as manager_count") + .expectLogical( + "LogicalAggregate(group=[{}], manager_count=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(=($2, 'MANAGER':VARCHAR), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("manager_count=3\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `JOB` = 'MANAGER' THEN 1 ELSE NULL END) `manager_count`\n" + + "FROM `scott`.`EMP`"); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index ed5da84600d..f752e502de7 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ 
b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -13,6 +13,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.and; import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.caseWhen; import static org.opensearch.sql.ast.dsl.AstDSL.cast; import static org.opensearch.sql.ast.dsl.AstDSL.compare; import static org.opensearch.sql.ast.dsl.AstDSL.decimalLiteral; @@ -42,6 +43,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.sort; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.ast.dsl.AstDSL.when; import static org.opensearch.sql.ast.dsl.AstDSL.xor; import com.google.common.collect.ImmutableMap; @@ -549,6 +551,24 @@ public void testCountFuncCallExpr() { defaultStatsArgs())); } + @Test + public void testCountEvalFuncCallExpr() { + assertEqual( + "source=t | stats count(eval(a > 0)) by b", + agg( + relation("t"), + exprList( + alias( + "count(eval(a > 0))", + aggregate( + "count", + caseWhen( + null, when(compare(">", field("a"), intLiteral(0)), intLiteral(1)))))), + emptyList(), + exprList(alias("b", field("b"))), + defaultStatsArgs())); + } + @Test public void testDistinctCount() { assertEqual(