diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Window.java b/core/src/main/java/org/opensearch/sql/ast/tree/Window.java index 818e78120ec..fbdf8e163a7 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Window.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Window.java @@ -21,6 +21,8 @@ public class Window extends UnresolvedPlan { private final List windowFunctionList; + private final List groupList; + private final boolean bucketNullable; @ToString.Exclude private UnresolvedPlan child; @Override diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 4848415c360..158c25688f7 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -1614,9 +1614,32 @@ private static void buildDedupNotNull( @Override public RelNode visitWindow(Window node, CalcitePlanContext context) { visitChildren(node, context); + + List groupList = node.getGroupList(); + boolean hasGroup = groupList != null && !groupList.isEmpty(); + boolean bucketNullable = node.isBucketNullable(); + List overExpressions = node.getWindowFunctionList().stream().map(w -> rexVisitor.analyze(w, context)).toList(); - context.relBuilder.projectPlus(overExpressions); + + if (hasGroup && !bucketNullable) { + // construct groupNotNull predicate + List groupByList = + groupList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList(); + List notNullList = + PlanUtils.getSelectColumns(groupByList).stream() + .map(context.relBuilder::field) + .map(context.relBuilder::isNotNull) + .toList(); + RexNode groupNotNull = context.relBuilder.and(notNullList); + + // wrap each expr: CASE WHEN groupNotNull THEN rawExpr ELSE CAST(NULL AS rawType) END + List wrappedOverExprs = + wrapWindowFunctionsWithGroupNotNull(overExpressions, groupNotNull, context); + 
context.relBuilder.projectPlus(wrappedOverExprs); + } else { + context.relBuilder.projectPlus(overExpressions); + } return context.relBuilder.peek(); } diff --git a/docs/user/ppl/cmd/eventstats.rst b/docs/user/ppl/cmd/eventstats.rst index 755af0486e4..cf4ac0d9b02 100644 --- a/docs/user/ppl/cmd/eventstats.rst +++ b/docs/user/ppl/cmd/eventstats.rst @@ -40,9 +40,14 @@ The ``stats`` and ``eventstats`` commands are both used for calculating statisti Syntax ====== -eventstats ... [by-clause] +eventstats [bucket_nullable=bool] ... [by-clause] * function: mandatory. An aggregation function or window function. +* bucket_nullable: optional. Controls whether the eventstats command considers null buckets as a valid group in group-by aggregations. When set to ``false``, it will not treat null group-by values as a distinct group during aggregation. **Default:** Determined by ``plugins.ppl.syntax.legacy.preferred``. + + * When ``plugins.ppl.syntax.legacy.preferred=true``, ``bucket_nullable`` defaults to ``true`` + * When ``plugins.ppl.syntax.legacy.preferred=false``, ``bucket_nullable`` defaults to ``false`` + * by-clause: optional. Groups results by specified fields or expressions. Syntax: by [span-expression,] [field,]... **Default:** aggregation over the entire result set. * span-expression: optional, at most one. Splits field into buckets by intervals. Syntax: span(field_expr, interval_expr). For example, ``span(age, 10)`` creates 10-year age buckets, ``span(timestamp, 1h)`` creates hourly buckets. 
@@ -126,3 +131,32 @@ PPL query:: | 13 | F | 28 | 1 | | 18 | M | 33 | 2 | +----------------+--------+-----+-----+ + +Example 3: Null buckets handling +================================ + +PPL query:: + + os> source=accounts | eventstats bucket_nullable=false count() as cnt by employer | fields account_number, firstname, employer, cnt | sort account_number; + fetched rows / total rows = 4/4 + +----------------+-----------+----------+------+ + | account_number | firstname | employer | cnt | + |----------------+-----------+----------+------| + | 1 | Amber | Pyrami | 1 | + | 6 | Hattie | Netagy | 1 | + | 13 | Nanette | Quility | 1 | + | 18 | Dale | null | null | + +----------------+-----------+----------+------+ + +PPL query:: + + os> source=accounts | eventstats bucket_nullable=true count() as cnt by employer | fields account_number, firstname, employer, cnt | sort account_number; + fetched rows / total rows = 4/4 + +----------------+-----------+----------+-----+ + | account_number | firstname | employer | cnt | + |----------------+-----------+----------+-----| + | 1 | Amber | Pyrami | 1 | + | 6 | Hattie | Netagy | 1 | + | 13 | Nanette | Quility | 1 | + | 18 | Dale | null | 1 | + +----------------+-----------+----------+-----+ diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index c981dfee8cb..c2dce34fc38 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -618,6 +618,16 @@ public void testEventstatsDistinctCountFunctionExplain() throws IOException { assertJsonEqualsIgnoreId(expected, result); } + @Test + public void testEventstatsNullBucketExplain() throws IOException { + String query = + "source=opensearch-sql_test_index_account | eventstats bucket_nullable=false count() by" + + " state"; + var result = 
explainQueryYaml(query); + String expected = loadExpectedPlan("explain_eventstats_null_bucket.yaml"); + assertYamlEqualsIgnoreId(expected, result); + } + @Test public void testStreamstatsDistinctCountExplain() throws IOException { String query = diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java index 9839fff00c4..f1ee8df35ea 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLEventstatsIT.java @@ -165,6 +165,40 @@ public void testEventstatsByWithNull() throws IOException { rows("Hello", "USA", "New York", 4, 2023, 30, 1, 30, 30, 30)); } + @Test + public void testEventstatsByWithNullBucket() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eventstats bucket_nullable=false count() as cnt, avg(age) as avg," + + " min(age) as min, max(age) as max by country", + TEST_INDEX_STATE_COUNTRY_WITH_NULL)); + + verifyDataRows( + actual, + rows("Kevin", null, null, 4, 2023, null, null, null, null, null), + rows(null, "Canada", null, 4, 2023, 10, 3, 18.333333333333332, 10, 25), + rows("John", "Canada", "Ontario", 4, 2023, 25, 3, 18.333333333333332, 10, 25), + rows("Jane", "Canada", "Quebec", 4, 2023, 20, 3, 18.333333333333332, 10, 25), + rows("Jake", "USA", "California", 4, 2023, 70, 2, 50, 30, 70), + rows("Hello", "USA", "New York", 4, 2023, 30, 2, 50, 30, 70)); + + actual = + executeQuery( + String.format( + "source=%s | eventstats bucket_nullable=false count() as cnt, avg(age) as avg," + + " min(age) as min, max(age) as max by state", + TEST_INDEX_STATE_COUNTRY_WITH_NULL)); + verifyDataRows( + actual, + rows(null, "Canada", null, 4, 2023, 10, null, null, null, null), + rows("Kevin", null, null, 4, 2023, null, null, null, null, null), + rows("John", "Canada", "Ontario", 4, 2023, 
25, 1, 25, 25, 25), + rows("Jane", "Canada", "Quebec", 4, 2023, 20, 1, 20, 20, 20), + rows("Jake", "USA", "California", 4, 2023, 70, 1, 70, 70, 70), + rows("Hello", "USA", "New York", 4, 2023, 30, 1, 30, 30, 30)); + } + @Test public void testEventstatsBySpan() throws IOException { JSONObject actual = @@ -324,6 +358,26 @@ public void testMultipleEventstatsWithNull() throws IOException { rows("Hello", "USA", "New York", 4, 2023, 30, 30.0, 50.0)); } + @Test + public void testMultipleEventstatsWithNullBucket() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eventstats bucket_nullable=false avg(age) as avg_age by state, country" + + " | eventstats bucket_nullable=false avg(avg_age) as avg_state_age by" + + " country", + TEST_INDEX_STATE_COUNTRY_WITH_NULL)); + + verifyDataRows( + actual, + rows("Kevin", null, null, 4, 2023, null, null, null), + rows(null, "Canada", null, 4, 2023, 10, null, 22.5), + rows("Jane", "Canada", "Quebec", 4, 2023, 20, 20.0, 22.5), + rows("John", "Canada", "Ontario", 4, 2023, 25, 25.0, 22.5), + rows("Jake", "USA", "California", 4, 2023, 70, 70.0, 50.0), + rows("Hello", "USA", "New York", 4, 2023, 30, 30.0, 50.0)); + } + @Test public void testMultipleEventstatsWithEval() throws IOException { JSONObject actual = diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_eventstats_null_bucket.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_eventstats_null_bucket.yaml new file mode 100644 index 00000000000..ae969892eeb --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_eventstats_null_bucket.yaml @@ -0,0 +1,11 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], count()=[CASE(IS NOT NULL($7), COUNT() OVER (PARTITION BY $7), null:BIGINT)]) + 
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..12=[{inputs}], expr#13=[null:BIGINT], expr#14=[CASE($t11, $t12, $t13)], proj#0..10=[{exprs}], count()=[$t14]) + EnumerableWindow(window#0=[window(partition {7} aggs [COUNT()])]) + EnumerableCalc(expr#0..10=[{inputs}], expr#11=[IS NOT NULL($t7)], proj#0..11=[{exprs}]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["account_number","firstname","address","balance","gender","city","employer","state","age","email","lastname"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_eventstats_null_bucket.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_eventstats_null_bucket.yaml new file mode 100644 index 00000000000..ad8f22e9421 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_eventstats_null_bucket.yaml @@ -0,0 +1,11 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], count()=[CASE(IS NOT NULL($7), COUNT() OVER (PARTITION BY $7), null:BIGINT)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..12=[{inputs}], expr#13=[null:BIGINT], expr#14=[CASE($t11, $t12, $t13)], proj#0..10=[{exprs}], count()=[$t14]) + EnumerableWindow(window#0=[window(partition {7} aggs 
[COUNT()])]) + EnumerableCalc(expr#0..16=[{inputs}], expr#17=[IS NOT NULL($t7)], proj#0..10=[{exprs}], $11=[$t17]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) \ No newline at end of file diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 6a542659047..5a4af885b90 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -255,7 +255,7 @@ dedupSplitArg ; eventstatsCommand - : EVENTSTATS eventstatsAggTerm (COMMA eventstatsAggTerm)* (statsByClause)? + : EVENTSTATS (bucketNullableArg)? eventstatsAggTerm (COMMA eventstatsAggTerm)* (statsByClause)? ; streamstatsCommand diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ed66682a981..3ffff5f9442 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -481,14 +481,24 @@ public UnresolvedPlan visitStatsCommand(StatsCommandContext ctx) { /** Eventstats command. */ public UnresolvedPlan visitEventstatsCommand(OpenSearchPPLParser.EventstatsCommandContext ctx) { + // 1. Parse arguments from the eventstats command + List argExprList = ArgumentFactory.getArgumentList(ctx, settings); + ArgumentMap arguments = ArgumentMap.of(argExprList); + + // bucket_nullable + boolean bucketNullable = + (Boolean) arguments.getOrDefault(Argument.BUCKET_NULLABLE, Literal.TRUE).getValue(); + + // 2. 
Build groupList + List groupList = getPartitionExprList(ctx.statsByClause()); + ImmutableList.Builder windownFunctionListBuilder = new ImmutableList.Builder<>(); for (OpenSearchPPLParser.EventstatsAggTermContext aggCtx : ctx.eventstatsAggTerm()) { UnresolvedExpression windowFunction = internalVisitExpression(aggCtx.windowFunction()); // set partition by list for window function if (windowFunction instanceof WindowFunction) { - ((WindowFunction) windowFunction) - .setPartitionByList(getPartitionExprList(ctx.statsByClause())); + ((WindowFunction) windowFunction).setPartitionByList(groupList); } String name = aggCtx.alias == null @@ -498,7 +508,7 @@ public UnresolvedPlan visitEventstatsCommand(OpenSearchPPLParser.EventstatsComma windownFunctionListBuilder.add(alias); } - return new Window(windownFunctionListBuilder.build()); + return new Window(windownFunctionListBuilder.build(), groupList, bucketNullable); } /** Streamstats command. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index 41e9e91535b..9fab9ba9a0f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -26,6 +26,7 @@ import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DefaultSortFieldContext; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EventstatsCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PrefixSortFieldContext; @@ -111,6 +112,22 @@ public static List getArgumentList(StreamstatsCommandContext ctx) { : 
new Argument("global", new Literal(true, DataType.BOOLEAN))); } + /** + * Get list of {@link Argument}. + * + * @param ctx EventstatsCommandContext instance + * @return the list of arguments fetched from the eventstats command + */ + public static List getArgumentList(EventstatsCommandContext ctx, Settings settings) { + return Collections.singletonList( + ctx.bucketNullableArg() != null && !ctx.bucketNullableArg().isEmpty() + ? new Argument( + Argument.BUCKET_NULLABLE, getArgumentValue(ctx.bucketNullableArg().bucket_nullable)) + : new Argument( + Argument.BUCKET_NULLABLE, + legacyPreferred(settings) ? Literal.TRUE : Literal.FALSE)); + } + /** * Get list of {@link Argument}. * diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTest.java index cd808621407..24f489a739f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLEventstatsTest.java @@ -70,4 +70,25 @@ public void testEventstatsAvg() { + "FROM `scott`.`EMP`"; verifyPPLToSparkSQL(root, expectedSparkSql); } + + @Test + public void testEventstatsNullBucket() { + String ppl = "source=EMP | eventstats bucket_nullable=false avg(SAL) by DEPTNO"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], avg(SAL)=[CASE(IS NOT NULL($7), /(SUM($5) OVER (PARTITION" + + " BY $7), CAST(COUNT($5) OVER (PARTITION BY $7)):DOUBLE NOT NULL), null:DOUBLE)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, CASE WHEN" + + " `DEPTNO` IS NOT NULL THEN (SUM(`SAL`) OVER (PARTITION BY `DEPTNO` RANGE BETWEEN" + + " UNBOUNDED PRECEDING AND UNBOUNDED 
FOLLOWING)) / CAST(COUNT(`SAL`) OVER (PARTITION" + + " BY `DEPTNO` RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS DOUBLE)" + + " ELSE NULL END `avg(SAL)`\n" + + "FROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } }