Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
}

public enum TrendlineType {
SMA
SMA,
WMA
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFrame;
import org.opensearch.sql.ast.expression.WindowFrame.FrameType;
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
Expand Down Expand Up @@ -86,11 +87,13 @@
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.exception.CalciteUnsupportedException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
Expand Down Expand Up @@ -975,6 +978,126 @@ public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context

@Override
public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
throw new CalciteUnsupportedException("Trendline command is unsupported in Calcite");
visitChildren(node, context);

node.getSortByField()
.ifPresent(
sortField -> {
SortOption sortOption = analyzeSortOption(sortField.getFieldArgs());
RexNode field = rexVisitor.analyze(sortField, context);
if (sortOption == DEFAULT_DESC) {
context.relBuilder.sort(context.relBuilder.desc(field));
} else {
context.relBuilder.sort(field);
}
});

List<RexNode> trendlineNodes = new ArrayList<>();
List<String> aliases = new ArrayList<>();
node.getComputations()
.forEach(
trendlineComputation -> {
RexNode field = rexVisitor.analyze(trendlineComputation.getDataField(), context);
context.relBuilder.filter(context.relBuilder.isNotNull(field));

WindowFrame windowFrame =
WindowFrame.of(
FrameType.ROWS,
StringUtils.format(
"%d PRECEDING", trendlineComputation.getNumberOfDataPoints() - 1),
"CURRENT ROW");
RexNode countExpr =
PlanUtils.makeOver(
context,
BuiltinFunctionName.COUNT,
null,
List.of(),
List.of(),
List.of(),
windowFrame);
// CASE WHEN count() over (ROWS (windowSize-1) PRECEDING) > windowSize - 1
RexNode whenConditionExpr =
PPLFuncImpTable.INSTANCE.resolve(
context.rexBuilder,
">",
countExpr,
context.relBuilder.literal(trendlineComputation.getNumberOfDataPoints() - 1));

RexNode thenExpr;
switch (trendlineComputation.getComputationType()) {
case TrendlineType.SMA:
// THEN avg(field) over (ROWS (windowSize-1) PRECEDING)
thenExpr =
PlanUtils.makeOver(
context,
BuiltinFunctionName.AVG,
field,
List.of(),
List.of(),
List.of(),
windowFrame);
break;
case TrendlineType.WMA:
// THEN wma expression
thenExpr =
buildWmaRexNode(
field,
trendlineComputation.getNumberOfDataPoints(),
windowFrame,
context);
break;
default:
throw new IllegalStateException("Unsupported trendline type");
}

// ELSE NULL
RexNode elseExpr = context.relBuilder.literal(null);

List<RexNode> caseOperands = new ArrayList<>();
caseOperands.add(whenConditionExpr);
caseOperands.add(thenExpr);
caseOperands.add(elseExpr);
RexNode trendlineNode =
context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
trendlineNodes.add(trendlineNode);
aliases.add(trendlineComputation.getAlias());
});

projectPlusOverriding(trendlineNodes, aliases, context);
return context.relBuilder.peek();
}

private RexNode buildWmaRexNode(
RexNode field,
Integer numberOfDataPoints,
WindowFrame windowFrame,
CalcitePlanContext context) {

// Divisor: 1 + 2 + 3 + ... + windowSize, aka (windowSize * (windowSize + 1) / 2)
RexNode divisor = context.relBuilder.literal(numberOfDataPoints * (numberOfDataPoints + 1) / 2);

// Divider: 1 * NTH_VALUE(field, 1) + 2 * NTH_VALUE(field, 2) + ... + windowSize *
// NTH_VALUE(field, windowSize)
RexNode divider = context.relBuilder.literal(0);
for (int i = 1; i <= numberOfDataPoints; i++) {
RexNode nthValueExpr =
PlanUtils.makeOver(
context,
BuiltinFunctionName.NTH_VALUE,
field,
List.of(context.relBuilder.literal(i)),
List.of(),
List.of(),
windowFrame);
divider =
context.relBuilder.call(
SqlStdOperatorTable.PLUS,
divider,
context.relBuilder.call(
SqlStdOperatorTable.MULTIPLY, nthValueExpr, context.relBuilder.literal(i)));
}
// Divider / CAST(Divisor, DOUBLE)
return context.relBuilder.call(
SqlStdOperatorTable.DIVIDE, divider, context.relBuilder.cast(divisor, SqlTypeName.DOUBLE));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ static RexNode makeOver(
true,
lowerBound,
upperBound);
case NTH_VALUE:
return withOver(
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
partitions,
orderKeys,
true,
lowerBound,
upperBound);
default:
return withOver(
makeAggCall(context, functionName, false, field, argList),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ public enum BuiltinFunctionName {
IS_BLANK(FunctionName.of("isblank")),

ROW_NUMBER(FunctionName.of("row_number")),
NTH_VALUE(FunctionName.of("nth_value")),
RANK(FunctionName.of("rank")),
DENSE_RANK(FunctionName.of("dense_rank")),

Expand Down
3 changes: 3 additions & 0 deletions docs/user/ppl/cmd/trendline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ PPL query::
| 15.5 |
+--------------------------+

Limitation
==========
The ``trendline`` command will filter out all NULL values to make sure result correctness because it's meaningless to count NULL values. But this may reduce lines in result for further processing.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to:
Starting with version 3.1.0, the trendline command requires all values in the specified field to be non-null. Any null values present in the calculation field will be automatically excluded from the command's output.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,6 @@ public void init() throws Exception {
disallowCalciteFallback();
}

@Override
public void testTrendlinePushDownExplain() throws Exception {
withFallbackEnabled(
() -> {
try {
super.testTrendlinePushDownExplain();
} catch (Exception e) {
throw new RuntimeException(e);
}
},
"https://github.com/opensearch-project/sql/issues/3466");
}

@Override
public void testTrendlineWithSortPushDownExplain() throws Exception {
withFallbackEnabled(
() -> {
try {
super.testTrendlineWithSortPushDownExplain();
} catch (Exception e) {
throw new RuntimeException(e);
}
},
"https://github.com/opensearch-project/sql/issues/3466");
}

@Override
@Ignore("test only in v2")
public void testExplainModeUnsupportedInV2() throws IOException {}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.standalone;

import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK;
import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK_WITH_NULL_VALUES;
import static org.opensearch.sql.util.MatcherUtils.rows;
import static org.opensearch.sql.util.MatcherUtils.schema;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
import static org.opensearch.sql.util.MatcherUtils.verifySchema;

import java.io.IOException;
import org.json.JSONObject;
import org.junit.jupiter.api.Test;

public class CalcitePPLTrendlineIT extends CalcitePPLIntegTestCase {

@LantaoJin LantaoJin Jun 9, 2025

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing CalcitePPLTrendlinePushdownIT if we add IT in standalone

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's added now.


@Override
public void init() throws IOException {
super.init();
loadIndex(Index.BANK);
loadIndex(Index.BANK_WITH_NULL_VALUES);
}

@Test
public void testTrendlineSma() throws IOException {
JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 30000 | trendline sma(3, balance) as balance_trend |"
+ " fields balance_trend",
TEST_INDEX_BANK));
verifySchema(result, schema("balance_trend", "double"));
verifyDataRows(
result, rows((Object) null), rows((Object) null), rows(37534.333333333336), rows(40488));
}

@Test
public void testTrendlineWma() throws IOException {
JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 30000 | trendline wma(3, balance) as balance_trend |"
+ " fields balance_trend",
TEST_INDEX_BANK));
verifySchema(result, schema("balance_trend", "double"));
verifyDataRows(
result, rows((Object) null), rows((Object) null), rows(37753.5), rows(43029.333333333336));
}

@Test
public void testTrendlineMultipleFields() throws Exception {
JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 30000 | trendline sma(2, balance) as sma wma(3,"
+ " balance) as wma | fields balance, sma, wma",
TEST_INDEX_BANK));
verifySchema(
result, schema("balance", "long"), schema("sma", "double"), schema("wma", "double"));
verifyDataRows(
result,
rows(39225, null, null),
rows(32838, 36031.5, null),
rows(40540, 36689, 37753.5),
rows(48086, 44313, 43029.333333333336));
}

@Test
public void testTrendlineNoAlias() throws Exception {
JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 30000 | trendline sma(2, balance) | fields"
+ " balance, balance_trendline",
TEST_INDEX_BANK));
verifySchema(result, schema("balance", "long"), schema("balance_trendline", "double"));
verifyDataRows(
result, rows(39225, null), rows(32838, 36031.5), rows(40540, 36689), rows(48086, 44313));
}

@Test
public void testTrendlineOverwritesExisingField() throws Exception {
JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 30000 | trendline sma(2, balance) as balance | fields"
+ " balance",
TEST_INDEX_BANK));
verifySchema(result, schema("balance", "double"));
verifyDataRows(result, rows((Object) null), rows(36031.5), rows(36689), rows(44313));
}

@Test
public void testTrendlineWithSort() throws Exception {
JSONObject result =
executeQuery(
String.format(
"source=%s | where balance > 30000 | trendline sort - balance sma(2, balance) |"
+ " fields balance, balance_trendline",
TEST_INDEX_BANK));
verifySchema(result, schema("balance", "long"), schema("balance_trendline", "double"));
verifyDataRows(
result, rows(48086, null), rows(40540, 44313), rows(39225, 39882.5), rows(32838, 36031.5));
}

@Test
public void testTrendlinePreFilterNullValues() throws Exception {
JSONObject result =
executeQuery(
String.format(
"source=%s | trendline sma(2, balance) | fields" + " balance, balance_trendline",
TEST_INDEX_BANK_WITH_NULL_VALUES));
verifySchema(result, schema("balance", "long"), schema("balance_trendline", "double"));
verifyDataRows(
result, rows(39225, null), rows(32838, 36031.5), rows(4180, 18509), rows(48086, 26133));
}
}
10 changes: 8 additions & 2 deletions integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ public void testFillNullPushDownExplain() throws Exception {

@Test
public void testTrendlinePushDownExplain() throws Exception {
String expected = loadFromFile("expectedOutput/ppl/explain_trendline_push.json");
String expected =
isCalciteEnabled()
? loadFromFile("expectedOutput/calcite/explain_trendline_push.json")
: loadFromFile("expectedOutput/ppl/explain_trendline_push.json");

assertJsonEqualsIgnoreId(
expected,
Expand All @@ -236,7 +239,10 @@ public void testTrendlinePushDownExplain() throws Exception {

@Test
public void testTrendlineWithSortPushDownExplain() throws Exception {
String expected = loadFromFile("expectedOutput/ppl/explain_trendline_sort_push.json");
String expected =
isCalciteEnabled()
? loadFromFile("expectedOutput/calcite/explain_trendline_sort_push.json")
: loadFromFile("expectedOutput/ppl/explain_trendline_sort_push.json");

assertJsonEqualsIgnoreId(
expected,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"calcite": {
"logical": "LogicalProject(ageTrend=[CASE(>(COUNT() OVER (ROWS 1 PRECEDING), 1), /(SUM($8) OVER (ROWS 1 PRECEDING), CAST(COUNT($8) OVER (ROWS 1 PRECEDING)):DOUBLE NOT NULL), null:NULL)])\n LogicalFilter(condition=[IS NOT NULL($8)])\n LogicalSort(fetch=[5])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
"physical": "EnumerableCalc(expr#0..3=[{inputs}], expr#4=[1], expr#5=[>($t1, $t4)], expr#6=[CAST($t3):DOUBLE NOT NULL], expr#7=[/($t2, $t6)], expr#8=[null:NULL], expr#9=[CASE($t5, $t7, $t8)], ageTrend=[$t9])\n EnumerableWindow(window#0=[window(rows between $1 PRECEDING and CURRENT ROW aggs [COUNT(), $SUM0($0), COUNT($0)])])\n EnumerableCalc(expr#0=[{inputs}], expr#1=[IS NOT NULL($t0)], age=[$t0], $condition=[$t1])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[LIMIT->5, PROJECT->[age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"age\"],\"excludes\":[]}}, requestedTotalSize=5, pageSize=null, startFrom=0)])\n"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"calcite": {
"logical": "LogicalProject(ageTrend=[CASE(>(COUNT() OVER (ROWS 1 PRECEDING), 1), /(SUM($8) OVER (ROWS 1 PRECEDING), CAST(COUNT($8) OVER (ROWS 1 PRECEDING)):DOUBLE NOT NULL), null:NULL)])\n LogicalFilter(condition=[IS NOT NULL($8)])\n LogicalSort(sort0=[$8], dir0=[ASC])\n LogicalSort(fetch=[5])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n",
"physical": "EnumerableCalc(expr#0..3=[{inputs}], expr#4=[1], expr#5=[>($t1, $t4)], expr#6=[CAST($t3):DOUBLE NOT NULL], expr#7=[/($t2, $t6)], expr#8=[null:NULL], expr#9=[CASE($t5, $t7, $t8)], ageTrend=[$t9])\n EnumerableWindow(window#0=[window(rows between $1 PRECEDING and CURRENT ROW aggs [COUNT(), $SUM0($0), COUNT($0)])])\n EnumerableCalc(expr#0=[{inputs}], expr#1=[IS NOT NULL($t0)], age=[$t0], $condition=[$t1])\n EnumerableSort(sort0=[$0], dir0=[ASC])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[LIMIT->5, PROJECT->[age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"age\"],\"excludes\":[]}}, requestedTotalSize=5, pageSize=null, startFrom=0)])\n"
}
}
1 change: 1 addition & 0 deletions ppl/src/main/antlr/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ NUM: 'NUM';

// TRENDLINE KEYWORDS
SMA: 'SMA';
WMA: 'WMA';

// ARGUMENT KEYWORDS
KEEPEMPTY: 'KEEPEMPTY';
Expand Down
Loading
Loading