2 changes: 2 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Window.java
@@ -21,6 +21,8 @@
public class Window extends UnresolvedPlan {

private final List<UnresolvedExpression> windowFunctionList;
private final List<UnresolvedExpression> groupList;
private final boolean bucketNullable;
@ToString.Exclude private UnresolvedPlan child;

@Override

@@ -1614,9 +1614,32 @@ private static void buildDedupNotNull(
@Override
public RelNode visitWindow(Window node, CalcitePlanContext context) {
visitChildren(node, context);

List<UnresolvedExpression> groupList = node.getGroupList();
boolean hasGroup = groupList != null && !groupList.isEmpty();
boolean bucketNullable = node.isBucketNullable();

List<RexNode> overExpressions =
node.getWindowFunctionList().stream().map(w -> rexVisitor.analyze(w, context)).toList();

if (hasGroup && !bucketNullable) {
// construct groupNotNull predicate
List<RexNode> groupByList =
groupList.stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
List<RexNode> notNullList =
PlanUtils.getSelectColumns(groupByList).stream()
.map(context.relBuilder::field)
.map(context.relBuilder::isNotNull)
.toList();
RexNode groupNotNull = context.relBuilder.and(notNullList);

// wrap each expr: CASE WHEN groupNotNull THEN rawExpr ELSE CAST(NULL AS rawType) END
List<RexNode> wrappedOverExprs =
wrapWindowFunctionsWithGroupNotNull(overExpressions, groupNotNull, context);
context.relBuilder.projectPlus(wrappedOverExprs);
} else {
context.relBuilder.projectPlus(overExpressions);
}
return context.relBuilder.peek();
}
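
The helper wrapWindowFunctionsWithGroupNotNull is referenced above but defined outside this hunk. A minimal sketch of what it plausibly does, assuming Calcite's standard RexBuilder API and that CalcitePlanContext exposes a rexBuilder (illustrative only, not necessarily the PR's actual implementation):

private List<RexNode> wrapWindowFunctionsWithGroupNotNull(
    List<RexNode> overExpressions, RexNode groupNotNull, CalcitePlanContext context) {
  RexBuilder rexBuilder = context.rexBuilder; // assumed accessor
  return overExpressions.stream()
      .map(
          expr ->
              // CASE WHEN groupNotNull THEN expr ELSE CAST(NULL AS exprType) END
              rexBuilder.makeCall(
                  SqlStdOperatorTable.CASE,
                  groupNotNull,
                  expr,
                  rexBuilder.makeNullLiteral(expr.getType())))
      .toList();
}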

36 changes: 35 additions & 1 deletion docs/user/ppl/cmd/eventstats.rst
@@ -40,9 +40,14 @@ The ``stats`` and ``eventstats`` commands are both used for calculating statisti

Syntax
======
eventstats [bucket_nullable=bool] <function>... [by-clause]

* function: mandatory. An aggregation function or window function.
* bucket_nullable: optional. Controls whether the eventstats command considers null buckets a valid group in group-by aggregations. When set to ``false``, null group-by values are not treated as a distinct group during aggregation. **Default:** Determined by ``plugins.ppl.syntax.legacy.preferred``.

* When ``plugins.ppl.syntax.legacy.preferred=true``, ``bucket_nullable`` defaults to ``true``
* When ``plugins.ppl.syntax.legacy.preferred=false``, ``bucket_nullable`` defaults to ``false``

* by-clause: optional. Groups results by specified fields or expressions. Syntax: by [span-expression,] [field,]... **Default:** aggregation over the entire result set.
* span-expression: optional, at most one. Splits field into buckets by intervals. Syntax: span(field_expr, interval_expr). For example, ``span(age, 10)`` creates 10-year age buckets, ``span(timestamp, 1h)`` creates hourly buckets.

@@ -126,3 +131,32 @@ PPL query::
| 13 | F | 28 | 1 |
| 18 | M | 33 | 2 |
+----------------+--------+-----+-----+

Example 3: Null bucket handling
================================

When ``bucket_nullable=false``, documents whose group-by field is null do not form a bucket, and their aggregated value is null.

PPL query::

os> source=accounts | eventstats bucket_nullable=false count() as cnt by employer | fields account_number, firstname, employer, cnt | sort account_number;
fetched rows / total rows = 4/4
+----------------+-----------+----------+------+
| account_number | firstname | employer | cnt |
|----------------+-----------+----------+------|
| 1 | Amber | Pyrami | 1 |
| 6 | Hattie | Netagy | 1 |
| 13 | Nanette | Quility | 1 |
| 18 | Dale | null | null |
+----------------+-----------+----------+------+

When ``bucket_nullable=true``, null group-by values are treated as a distinct bucket, matching the legacy behavior.

PPL query::

os> source=accounts | eventstats bucket_nullable=true count() as cnt by employer | fields account_number, firstname, employer, cnt | sort account_number;
fetched rows / total rows = 4/4
+----------------+-----------+----------+-----+
| account_number | firstname | employer | cnt |
|----------------+-----------+----------+-----|
| 1 | Amber | Pyrami | 1 |
| 6 | Hattie | Netagy | 1 |
| 13 | Nanette | Quility | 1 |
| 18 | Dale | null | 1 |
+----------------+-----------+----------+-----+

@@ -618,6 +618,16 @@ public void testEventstatsDistinctCountFunctionExplain() throws IOException {
assertJsonEqualsIgnoreId(expected, result);
}

@Test
public void testEventstatsNullBucketExplain() throws IOException {
String query =
"source=opensearch-sql_test_index_account | eventstats bucket_nullable=false count() by"
+ " state";
var result = explainQueryYaml(query);
String expected = loadExpectedPlan("explain_eventstats_null_bucket.yaml");
assertYamlEqualsIgnoreId(expected, result);
}

@Test
public void testStreamstatsDistinctCountExplain() throws IOException {
String query =

@@ -165,6 +165,40 @@ public void testEventstatsByWithNull() throws IOException {
rows("Hello", "USA", "New York", 4, 2023, 30, 1, 30, 30, 30));
}

@Test
public void testEventstatsByWithNullBucket() throws IOException {
JSONObject actual =
executeQuery(
String.format(
"source=%s | eventstats bucket_nullable=false count() as cnt, avg(age) as avg,"
+ " min(age) as min, max(age) as max by country",
TEST_INDEX_STATE_COUNTRY_WITH_NULL));

verifyDataRows(
actual,
rows("Kevin", null, null, 4, 2023, null, null, null, null, null),
rows(null, "Canada", null, 4, 2023, 10, 3, 18.333333333333332, 10, 25),
rows("John", "Canada", "Ontario", 4, 2023, 25, 3, 18.333333333333332, 10, 25),
rows("Jane", "Canada", "Quebec", 4, 2023, 20, 3, 18.333333333333332, 10, 25),
rows("Jake", "USA", "California", 4, 2023, 70, 2, 50, 30, 70),
rows("Hello", "USA", "New York", 4, 2023, 30, 2, 50, 30, 70));

actual =
executeQuery(
String.format(
"source=%s | eventstats bucket_nullable=false count() as cnt, avg(age) as avg,"
+ " min(age) as min, max(age) as max by state",
TEST_INDEX_STATE_COUNTRY_WITH_NULL));
verifyDataRows(
actual,
rows(null, "Canada", null, 4, 2023, 10, null, null, null, null),
rows("Kevin", null, null, 4, 2023, null, null, null, null, null),
rows("John", "Canada", "Ontario", 4, 2023, 25, 1, 25, 25, 25),
rows("Jane", "Canada", "Quebec", 4, 2023, 20, 1, 20, 20, 20),
rows("Jake", "USA", "California", 4, 2023, 70, 1, 70, 70, 70),
rows("Hello", "USA", "New York", 4, 2023, 30, 1, 30, 30, 30));
}

@Test
public void testEventstatsBySpan() throws IOException {
JSONObject actual =

@@ -324,6 +358,26 @@ public void testMultipleEventstatsWithNull() throws IOException {
rows("Hello", "USA", "New York", 4, 2023, 30, 30.0, 50.0));
}

@Test
public void testMultipleEventstatsWithNullBucket() throws IOException {
JSONObject actual =
executeQuery(
String.format(
"source=%s | eventstats bucket_nullable=false avg(age) as avg_age by state, country"
+ " | eventstats bucket_nullable=false avg(avg_age) as avg_state_age by"
+ " country",
TEST_INDEX_STATE_COUNTRY_WITH_NULL));

verifyDataRows(
actual,
rows("Kevin", null, null, 4, 2023, null, null, null),
rows(null, "Canada", null, 4, 2023, 10, null, 22.5),
rows("Jane", "Canada", "Quebec", 4, 2023, 20, 20.0, 22.5),
rows("John", "Canada", "Ontario", 4, 2023, 25, 25.0, 22.5),
rows("Jake", "USA", "California", 4, 2023, 70, 70.0, 50.0),
rows("Hello", "USA", "New York", 4, 2023, 30, 30.0, 50.0));
}

@Test
public void testMultipleEventstatsWithEval() throws IOException {
JSONObject actual =
@@ -0,0 +1,11 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], count()=[CASE(IS NOT NULL($7), COUNT() OVER (PARTITION BY $7), null:BIGINT)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableCalc(expr#0..12=[{inputs}], expr#13=[null:BIGINT], expr#14=[CASE($t11, $t12, $t13)], proj#0..10=[{exprs}], count()=[$t14])
EnumerableWindow(window#0=[window(partition {7} aggs [COUNT()])])
EnumerableCalc(expr#0..10=[{inputs}], expr#11=[IS NOT NULL($t7)], proj#0..11=[{exprs}])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["account_number","firstname","address","balance","gender","city","employer","state","age","email","lastname"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])
@@ -0,0 +1,11 @@
calcite:
logical: |
LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])
LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], count()=[CASE(IS NOT NULL($7), COUNT() OVER (PARTITION BY $7), null:BIGINT)])
CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
physical: |
EnumerableLimit(fetch=[10000])
EnumerableCalc(expr#0..12=[{inputs}], expr#13=[null:BIGINT], expr#14=[CASE($t11, $t12, $t13)], proj#0..10=[{exprs}], count()=[$t14])
EnumerableWindow(window#0=[window(partition {7} aggs [COUNT()])])
EnumerableCalc(expr#0..16=[{inputs}], expr#17=[IS NOT NULL($t7)], proj#0..10=[{exprs}], $11=[$t17])
CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])
2 changes: 1 addition & 1 deletion ppl/src/main/antlr/OpenSearchPPLParser.g4
@@ -255,7 +255,7 @@ dedupSplitArg
;

eventstatsCommand
: EVENTSTATS (bucketNullableArg)? eventstatsAggTerm (COMMA eventstatsAggTerm)* (statsByClause)?
;
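
The bucketNullableArg rule itself is outside this hunk. Based on how AstBuilder accesses ctx.bucketNullableArg().bucket_nullable later in this diff, the rule presumably looks something like the following (token and label names are assumptions):

bucketNullableArg
    : BUCKET_NULLABLE EQUAL bucket_nullable = booleanLiteral
    ;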

streamstatsCommand
16 changes: 13 additions & 3 deletions ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -481,14 +481,24 @@ public UnresolvedPlan visitStatsCommand(StatsCommandContext ctx) {

/** Eventstats command. */
public UnresolvedPlan visitEventstatsCommand(OpenSearchPPLParser.EventstatsCommandContext ctx) {
// 1. Parse arguments from the eventstats command
List<Argument> argExprList = ArgumentFactory.getArgumentList(ctx, settings);
ArgumentMap arguments = ArgumentMap.of(argExprList);

// bucket_nullable
boolean bucketNullable =
(Boolean) arguments.getOrDefault(Argument.BUCKET_NULLABLE, Literal.TRUE).getValue();

// 2. Build groupList
List<UnresolvedExpression> groupList = getPartitionExprList(ctx.statsByClause());

ImmutableList.Builder<UnresolvedExpression> windownFunctionListBuilder =
new ImmutableList.Builder<>();
for (OpenSearchPPLParser.EventstatsAggTermContext aggCtx : ctx.eventstatsAggTerm()) {
UnresolvedExpression windowFunction = internalVisitExpression(aggCtx.windowFunction());
// set partition by list for window function
if (windowFunction instanceof WindowFunction) {
((WindowFunction) windowFunction).setPartitionByList(groupList);
}
String name =
aggCtx.alias == null
Expand All @@ -498,7 +508,7 @@ public UnresolvedPlan visitEventstatsCommand(OpenSearchPPLParser.EventstatsComma
windownFunctionListBuilder.add(alias);
}

return new Window(windownFunctionListBuilder.build(), groupList, bucketNullable);
}

/** Streamstats command. */

@@ -26,6 +26,7 @@
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DedupCommandContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DefaultSortFieldContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EventstatsCommandContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IntegerLiteralContext;
import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PrefixSortFieldContext;

@@ -111,6 +112,22 @@ public static List<Argument> getArgumentList(StreamstatsCommandContext ctx) {
: new Argument("global", new Literal(true, DataType.BOOLEAN)));
}

/**
* Get list of {@link Argument}.
*
* @param ctx EventstatsCommandContext instance
* @return the list of arguments fetched from the eventstats command
*/
public static List<Argument> getArgumentList(EventstatsCommandContext ctx, Settings settings) {
return Collections.singletonList(
ctx.bucketNullableArg() != null && !ctx.bucketNullableArg().isEmpty()
? new Argument(
Argument.BUCKET_NULLABLE, getArgumentValue(ctx.bucketNullableArg().bucket_nullable))
: new Argument(
Argument.BUCKET_NULLABLE,
legacyPreferred(settings) ? Literal.TRUE : Literal.FALSE));
}
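
The legacyPreferred helper is referenced here but defined outside the hunk. A hedged sketch under the assumption that a Settings.Key constant backs plugins.ppl.syntax.legacy.preferred (the constant name below is a guess):

private static boolean legacyPreferred(Settings settings) {
  // Assumed key name; the real constant maps to the plugins.ppl.syntax.legacy.preferred setting
  Boolean legacy = settings.getSettingValue(Settings.Key.PPL_SYNTAX_LEGACY_PREFERRED);
  return Boolean.TRUE.equals(legacy);
}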

/**
* Get list of {@link Argument}.
*

@@ -70,4 +70,25 @@ public void testEventstatsAvg() {
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

@Test
public void testEventstatsNullBucket() {
String ppl = "source=EMP | eventstats bucket_nullable=false avg(SAL) by DEPTNO";
RelNode root = getRelNode(ppl);
String expectedLogical =
"LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5],"
+ " COMM=[$6], DEPTNO=[$7], avg(SAL)=[CASE(IS NOT NULL($7), /(SUM($5) OVER (PARTITION"
+ " BY $7), CAST(COUNT($5) OVER (PARTITION BY $7)):DOUBLE NOT NULL), null:DOUBLE)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);

String expectedSparkSql =
"SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`, CASE WHEN"
+ " `DEPTNO` IS NOT NULL THEN (SUM(`SAL`) OVER (PARTITION BY `DEPTNO` RANGE BETWEEN"
+ " UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)) / CAST(COUNT(`SAL`) OVER (PARTITION"
+ " BY `DEPTNO` RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS DOUBLE)"
+ " ELSE NULL END `avg(SAL)`\n"
+ "FROM `scott`.`EMP`";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}