Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions docs/user/ppl/cmd/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,37 @@ COUNT
Description
>>>>>>>>>>>

Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. The ``C()`` function can be used as an abbreviation for ``COUNT()``.
Usage: Returns a count of the number of expr in the rows retrieved. The ``C()`` function can be used as an abbreviation for ``COUNT()``. To perform filtered counting, wrap the condition in an ``eval`` expression.

Example::

os> source=accounts | stats count();
os> source=accounts | stats count(), c();
fetched rows / total rows = 1/1
+---------+
| count() |
|---------|
| 4 |
+---------+
+---------+-----+
| count() | c() |
|---------+-----|
| 4 | 4 |
+---------+-----+

Example of filtered counting::

os> source=accounts | stats count(eval(age > 30)) as mature_users;
fetched rows / total rows = 1/1
+--------------+
| mature_users |
|--------------|
| 3 |
+--------------+

Example of filtered counting with complex conditions::

os> source=accounts | stats c();
os> source=accounts | stats count(eval(age > 30 and balance > 25000)) as high_value_users;
fetched rows / total rows = 1/1
+-----+
| c() |
|-----|
| 4 |
+-----+
+------------------+
| high_value_users |
|------------------|
| 1 |
+------------------+

SUM
---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,72 @@ public void testAggByByteNumberWithScript() throws IOException {
verifyDataRows(response, rows(1, 4));
}

@Test
public void testCountEvalSimpleCondition() throws IOException {
  // count(eval(p)) counts only the rows for which predicate p holds.
  String query = String.format("source=%s | stats count(eval(age > 30)) as c", TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(response, schema("c", "bigint"));
  verifyDataRows(response, rows(6));
}

@Test
public void testCountEvalComplexCondition() throws IOException {
  // Predicates combined with "and" inside eval() are supported as well.
  String query =
      String.format(
          "source=%s | stats count(eval(balance > 20000 and age < 35)) as c", TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(response, schema("c", "bigint"));
  verifyDataRows(response, rows(3));
}

@Test
public void testCountEvalGroupBy() throws IOException {
  // Filtered counting composes with grouping: the predicate is evaluated per group.
  String query =
      String.format(
          "source=%s | stats count(eval(balance > 25000)) as high_balance by gender",
          TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(response, schema("gender", "string"), schema("high_balance", "bigint"));
  verifyDataRows(response, rows(3, "F"), rows(1, "M"));
}

@Test
public void testCountEvalWithMultipleAggregations() throws IOException {
  // Several filtered counts and an unfiltered count() can coexist in one stats command.
  String query =
      String.format(
          "source=%s | stats count(eval(age > 30)) as mature_count, "
              + "count(eval(balance > 25000)) as high_balance_count, "
              + "count() as total_count",
          TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(
      response,
      schema("mature_count", "bigint"),
      schema("high_balance_count", "bigint"),
      schema("total_count", "bigint"));
  verifyDataRows(response, rows(6, 4, 7));
}

@Test
public void testShortcutCEvalSimpleCondition() throws IOException {
  // The c() abbreviation accepts the same eval() predicate form as count().
  String query = String.format("source=%s | stats c(eval(age > 30)) as c", TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(response, schema("c", "bigint"));
  verifyDataRows(response, rows(6));
}

@Test
public void testShortcutCEvalComplexCondition() throws IOException {
  // c() with a compound predicate — mirrors testCountEvalComplexCondition.
  String query =
      String.format(
          "source=%s | stats c(eval(balance > 20000 and age < 35)) as c", TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(response, schema("c", "bigint"));
  verifyDataRows(response, rows(3));
}

@Test
public void testPercentileShortcuts() throws IOException {
JSONObject actual =
Expand Down
5 changes: 5 additions & 0 deletions ppl/src/main/antlr/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ statsAggTerm
// aggregation functions
statsFunction
: statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall
| (COUNT | C) LT_PRTHS evalExpression RT_PRTHS # countEvalFunctionCall
| (COUNT | C) LT_PRTHS RT_PRTHS # countAllFunctionCall
| PERCENTILE_SHORTCUT LT_PRTHS valueExpression RT_PRTHS # percentileShortcutFunctionCall
| (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall
Expand Down Expand Up @@ -501,6 +502,10 @@ valueExpression
| LT_PRTHS logicalExpression RT_PRTHS # nestedValueExpr
;

// Filtered-count predicate wrapper: eval(<logicalExpression>). Consumed only by the
// countEvalFunctionCall alternative of statsFunction, i.e. count(eval(p)) / c(eval(p)).
evalExpression
: EVAL LT_PRTHS logicalExpression RT_PRTHS
;

functionCall
: evalFunctionCall
| dataTypeFunctionCall
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,39 @@
package org.opensearch.sql.ppl.parser;

import static org.opensearch.sql.expression.function.BuiltinFunctionName.*;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountEvalFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DoubleLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FloatLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsQualifiedNameContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsTableQualifiedNameContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsWildcardQualifiedNameContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntervalLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalAndContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalNotContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.MultiFieldRelevanceFunctionContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SingleFieldRelevanceFunctionContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StringLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableSourceContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.WcFieldExpressionContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -41,6 +74,7 @@
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DoubleLiteralContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalExpressionContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FloatLiteralContext;
Expand Down Expand Up @@ -229,12 +263,28 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex
return new AggregateFunction("count", AllFields.of());
}

@Override
public UnresolvedExpression visitCountEvalFunctionCall(CountEvalFunctionCallContext ctx) {
  // visitEvalExpression rewrites eval(p) into a CASE expression that is non-null only
  // when p holds; COUNT over it therefore yields a filtered count.
  UnresolvedExpression caseExpression = visit(ctx.evalExpression());
  return new AggregateFunction("count", caseExpression);
}

@Override
public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) {
  // DISTINCT_COUNT / DC resolve to an exact distinct count; DISTINCT_COUNT_APPROX to the
  // approximate variant. The trailing "true" marks the aggregation as distinct.
  boolean approximate = ctx.DISTINCT_COUNT_APPROX() != null;
  String functionName = approximate ? "distinct_count_approx" : "count";
  return new AggregateFunction(functionName, visit(ctx.valueExpression()), true);
}

@Override
public UnresolvedExpression visitEvalExpression(EvalExpressionContext ctx) {
  // Rewrite "eval(p)" as "CASE WHEN p THEN 1 ELSE NULL END". Because COUNT skips NULL
  // values, counting the CASE result performs filtered counting for COUNT/DISTINCT_COUNT.
  // Note: at present only eval(<predicate>) inside counting functions is supported.
  UnresolvedExpression condition = visit(ctx.logicalExpression());
  return AstDSL.caseWhen(null, AstDSL.when(condition, AstDSL.intLiteral(1)));
}

@Override
public UnresolvedExpression visitPercentileApproxFunctionCall(
OpenSearchPPLParser.PercentileApproxFunctionCallContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,37 @@ private Node plan(PPLSyntaxParser parser, String query) {
return builder.visit(parser.parse(query));
}

/**
 * Fluent API for building count(eval) test cases. Provides a clean and readable way to define PPL
 * queries and their expected outcomes.
 *
 * @param ppl the PPL query to plan into a RelNode
 * @return a builder whose chainable expect* methods assert on the planned RelNode
 */
protected PPLQueryTestBuilder withPPLQuery(String ppl) {
return new PPLQueryTestBuilder(ppl);
}

/**
 * Chainable assertion helper over a single planned PPL query. The query is planned once in the
 * constructor; each expect* method runs one verification and returns {@code this} for chaining.
 */
protected class PPLQueryTestBuilder {
  private final RelNode rel;

  public PPLQueryTestBuilder(String ppl) {
    rel = getRelNode(ppl);
  }

  /** Assert the logical plan tree of the planned query. */
  public PPLQueryTestBuilder expectLogical(String expectedLogical) {
    verifyLogical(rel, expectedLogical);
    return this;
  }

  /** Assert the execution result of the planned query. */
  public PPLQueryTestBuilder expectResult(String expectedResult) {
    verifyResult(rel, expectedResult);
    return this;
  }

  /** Assert the Spark SQL rendering of the planned query. */
  public PPLQueryTestBuilder expectSparkSQL(String expectedSparkSql) {
    verifyPPLToSparkSQL(rel, expectedSparkSql);
    return this;
  }
}

/** Verify the logical plan of the given RelNode */
public void verifyLogical(RelNode rel, String expectedLogical) {
assertThat(rel, hasTree(expectedLogical));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl.calcite;

import org.apache.calcite.test.CalciteAssert;
import org.junit.Test;

/**
 * Unit tests for count(eval) functionality in CalcitePPL engine. Tests various scenarios of
 * filtered count aggregations.
 *
 * <p>Each test verifies three things for one PPL query: the Calcite logical plan (eval(p) must
 * lower to a CASE expression projected before the aggregate), the execution result against the
 * SCOTT sample schema, and the generated Spark SQL.
 */
public class CalcitePPLCountEvalTest extends CalcitePPLAbstractTest {

  public CalcitePPLCountEvalTest() {
    // SCOTT_WITH_TEMPORAL provides the EMP table used by every query below.
    super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
  }

  @Test
  public void testCountEvalSimpleCondition() {
    // Single comparison predicate: eval(SAL > 2000) becomes CASE(>($5, 2000), 1, null).
    withPPLQuery("source=EMP | stats count(eval(SAL > 2000)) as c")
        .expectLogical(
            "LogicalAggregate(group=[{}], c=[COUNT($0)])\n"
                + "  LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("c=6\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`");
  }

  @Test
  public void testCountEvalComplexCondition() {
    // Conjunction of two comparisons inside a single eval().
    withPPLQuery("source=EMP | stats count(eval(SAL > 2000 and DEPTNO < 30)) as c")
        .expectLogical(
            "LogicalAggregate(group=[{}], c=[COUNT($0)])\n"
                + "  LogicalProject($f2=[CASE(AND(>($5, 2000), <($7, 30)), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("c=5\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `SAL` > 2000 AND `DEPTNO` < 30 THEN 1 ELSE NULL END) `c`\n"
                + "FROM `scott`.`EMP`");
  }

  @Test
  public void testCountEvalStringComparison() {
    // String equality predicate.
    withPPLQuery("source=EMP | stats count(eval(JOB = 'MANAGER')) as manager_count")
        .expectLogical(
            "LogicalAggregate(group=[{}], manager_count=[COUNT($0)])\n"
                + "  LogicalProject($f1=[CASE(=($2, 'MANAGER':VARCHAR), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("manager_count=3\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `JOB` = 'MANAGER' THEN 1 ELSE NULL END) `manager_count`\n"
                + "FROM `scott`.`EMP`");
  }

  @Test
  public void testCountEvalArithmeticExpression() {
    // Arithmetic inside the predicate; NULL COMM makes the CASE result NULL, so such rows
    // are excluded from the count (result is 0 here).
    withPPLQuery("source=EMP | stats count(eval(SAL / COMM > 10)) as high_ratio")
        .expectLogical(
            "LogicalAggregate(group=[{}], high_ratio=[COUNT($0)])\n"
                + "  LogicalProject($f2=[CASE(>(DIVIDE($5, $6), 10), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("high_ratio=0\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `DIVIDE`(`SAL`, `COMM`) > 10 THEN 1 ELSE NULL END)"
                + " `high_ratio`\n"
                + "FROM `scott`.`EMP`");
  }

  @Test
  public void testCountEvalWithNullHandling() {
    // Null-check predicate via isnotnull().
    withPPLQuery("source=EMP | stats count(eval(isnotnull(MGR))) as non_null_mgr")
        .expectLogical(
            "LogicalAggregate(group=[{}], non_null_mgr=[COUNT($0)])\n"
                + "  LogicalProject($f1=[CASE(IS NOT NULL($3), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("non_null_mgr=13\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `MGR` IS NOT NULL THEN 1 ELSE NULL END) `non_null_mgr`\n"
                + "FROM `scott`.`EMP`");
  }

  @Test
  public void testShortcutCEvalSimpleCondition() {
    // The c() abbreviation must produce the identical plan as count().
    withPPLQuery("source=EMP | stats c(eval(SAL > 2000)) as c")
        .expectLogical(
            "LogicalAggregate(group=[{}], c=[COUNT($0)])\n"
                + "  LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("c=6\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`");
  }

  @Test
  public void testShortcutCEvalComplexCondition() {
    // c() with a string predicate — same plan as the count() equivalent.
    withPPLQuery("source=EMP | stats c(eval(JOB = 'MANAGER')) as manager_count")
        .expectLogical(
            "LogicalAggregate(group=[{}], manager_count=[COUNT($0)])\n"
                + "  LogicalProject($f1=[CASE(=($2, 'MANAGER':VARCHAR), 1, null:NULL)])\n"
                + "    LogicalTableScan(table=[[scott, EMP]])\n")
        .expectResult("manager_count=3\n")
        .expectSparkSQL(
            "SELECT COUNT(CASE WHEN `JOB` = 'MANAGER' THEN 1 ELSE NULL END) `manager_count`\n"
                + "FROM `scott`.`EMP`");
  }
}
Loading
Loading