Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
}

/** Trendline algorithms: simple moving average (SMA) and weighted moving average (WMA). */
public enum TrendlineType {
SMA
SMA,
WMA
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFrame;
import org.opensearch.sql.ast.expression.WindowFrame.FrameType;
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
Expand Down Expand Up @@ -86,11 +87,13 @@
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.calcite.plan.OpenSearchConstants;
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.exception.CalciteUnsupportedException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
Expand Down Expand Up @@ -975,6 +978,126 @@ public RelNode visitTableFunction(TableFunction node, CalcitePlanContext context

/**
 * Builds the Calcite plan for a {@code trendline} node on top of its child plan.
 *
 * <p>Steps, in order: visit the children; apply the optional sort; then, for every
 * computation, filter out rows whose data field is null and build
 * {@code CASE WHEN count() OVER w > n - 1 THEN <SMA or WMA over w> ELSE NULL}, where
 * {@code w} is the ROWS frame from {@code n - 1} rows preceding through the current row and
 * {@code n} is the computation's number of data points. Finally the expressions and their
 * aliases are handed to {@code projectPlusOverriding}.
 *
 * @param node the trendline AST node (optional sort field plus one or more computations)
 * @param context the plan context holding the RelBuilder/RexBuilder being mutated
 * @return the node currently on top of the RelBuilder stack
 * @throws IllegalStateException if a computation type is neither SMA nor WMA
 */
@Override
public RelNode visitTrendline(Trendline node, CalcitePlanContext context) {
// NOTE(review): this throw looks like the removed side of the diff hunk; as written it makes
// everything below unreachable — confirm it is absent from the final file.
throw new CalciteUnsupportedException("Trendline command is unsupported in Calcite");
visitChildren(node, context);

// Optional sort: descending only when the sort option resolves to DEFAULT_DESC.
node.getSortByField()
.ifPresent(
sortField -> {
SortOption sortOption = analyzeSortOption(sortField.getFieldArgs());
RexNode field = rexVisitor.analyze(sortField, context);
if (sortOption == DEFAULT_DESC) {
context.relBuilder.sort(context.relBuilder.desc(field));
} else {
context.relBuilder.sort(field);
}
});

// One CASE expression and one alias per computation, collected in matching order.
List<RexNode> trendlineNodes = new ArrayList<>();
List<String> aliases = new ArrayList<>();
node.getComputations()
.forEach(
trendlineComputation -> {
RexNode field = rexVisitor.analyze(trendlineComputation.getDataField(), context);
// Rows with a null data field are dropped from the plan before the window is evaluated.
context.relBuilder.filter(context.relBuilder.isNotNull(field));

// Frame: ROWS BETWEEN (n - 1) PRECEDING AND CURRENT ROW, i.e. at most n rows.
// NOTE(review): the makeOver calls below pass empty lists for what appear to be the
// partition and order keys, so row order presumably relies on the sort above — confirm.
WindowFrame windowFrame =
WindowFrame.of(
FrameType.ROWS,
StringUtils.format(
"%d PRECEDING", trendlineComputation.getNumberOfDataPoints() - 1),
"CURRENT ROW");
RexNode countExpr =
PlanUtils.makeOver(
context,
BuiltinFunctionName.COUNT,
null,
List.of(),
List.of(),
List.of(),
windowFrame);
// CASE WHEN count() over (ROWS (windowSize-1) PRECEDING) > windowSize - 1
// The count exceeds windowSize - 1 only once the frame holds windowSize rows, so rows
// without a full window fall through to the ELSE NULL branch.
RexNode whenConditionExpr =
PPLFuncImpTable.INSTANCE.resolve(
context.rexBuilder,
">",
countExpr,
context.relBuilder.literal(trendlineComputation.getNumberOfDataPoints() - 1));

RexNode thenExpr;
switch (trendlineComputation.getComputationType()) {
case TrendlineType.SMA:
// THEN avg(field) over (ROWS (windowSize-1) PRECEDING)
thenExpr =
PlanUtils.makeOver(
context,
BuiltinFunctionName.AVG,
field,
List.of(),
List.of(),
List.of(),
windowFrame);
break;
case TrendlineType.WMA:
// THEN wma expression
thenExpr =
buildWmaRexNode(
field,
trendlineComputation.getNumberOfDataPoints(),
windowFrame,
context);
break;
default:
throw new IllegalStateException("Unsupported trendline type");
}

// ELSE NULL
RexNode elseExpr = context.relBuilder.literal(null);

// CASE operands are laid out as [when, then, else].
List<RexNode> caseOperands = new ArrayList<>();
caseOperands.add(whenConditionExpr);
caseOperands.add(thenExpr);
caseOperands.add(elseExpr);
RexNode trendlineNode =
context.rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
trendlineNodes.add(trendlineNode);
aliases.add(trendlineComputation.getAlias());
});

projectPlusOverriding(trendlineNodes, aliases, context);
return context.relBuilder.peek();
}

/**
 * Builds the weighted-moving-average expression for one trendline computation.
 *
 * <p>The result is {@code (1*v1 + 2*v2 + ... + n*vn) / CAST(n*(n+1)/2 AS DOUBLE)}, where
 * {@code vk} is {@code NTH_VALUE(field, k)} evaluated over {@code windowFrame} and {@code n}
 * is {@code numberOfDataPoints}. The k-th row of the frame therefore carries weight k.
 *
 * @param field the expression whose values are averaged
 * @param numberOfDataPoints the window size n
 * @param windowFrame the frame every NTH_VALUE call is evaluated over
 * @param context the plan context supplying the RelBuilder
 * @return the quotient of the weighted sum and the sum of the weights
 */
private RexNode buildWmaRexNode(
    RexNode field,
    Integer numberOfDataPoints,
    WindowFrame windowFrame,
    CalcitePlanContext context) {
  final int windowSize = numberOfDataPoints;

  // Denominator: sum of the weights 1 + 2 + ... + windowSize = windowSize * (windowSize + 1) / 2
  final RexNode weightTotal = context.relBuilder.literal(windowSize * (windowSize + 1) / 2);

  // Numerator: 1 * NTH_VALUE(field, 1) + 2 * NTH_VALUE(field, 2) + ... accumulated from literal 0
  RexNode weightedSum = context.relBuilder.literal(0);
  for (int weight = 1; weight <= windowSize; weight++) {
    RexNode position = context.relBuilder.literal(weight);
    RexNode valueAtPosition =
        PlanUtils.makeOver(
            context,
            BuiltinFunctionName.NTH_VALUE,
            field,
            List.of(position),
            List.of(),
            List.of(),
            windowFrame);
    RexNode weightedTerm =
        context.relBuilder.call(
            SqlStdOperatorTable.MULTIPLY, valueAtPosition, context.relBuilder.literal(weight));
    weightedSum = context.relBuilder.call(SqlStdOperatorTable.PLUS, weightedSum, weightedTerm);
  }

  // The denominator is cast to DOUBLE before dividing: weightedSum / CAST(weightTotal AS DOUBLE)
  RexNode denominator = context.relBuilder.cast(weightTotal, SqlTypeName.DOUBLE);
  return context.relBuilder.call(SqlStdOperatorTable.DIVIDE, weightedSum, denominator);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ static RexNode makeOver(
true,
lowerBound,
upperBound);
case NTH_VALUE:
return withOver(
context.relBuilder.aggregateCall(SqlStdOperatorTable.NTH_VALUE, field, argList.get(0)),
partitions,
orderKeys,
true,
lowerBound,
upperBound);
default:
return withOver(
makeAggCall(context, functionName, false, field, argList),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ public enum BuiltinFunctionName {
IS_BLANK(FunctionName.of("isblank")),

ROW_NUMBER(FunctionName.of("row_number")),
NTH_VALUE(FunctionName.of("nth_value")),
RANK(FunctionName.of("rank")),
DENSE_RANK(FunctionName.of("dense_rank")),

Expand Down
69 changes: 57 additions & 12 deletions docs/user/ppl/cmd/trendline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,36 @@ Description

Syntax
============
`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`
`TRENDLINE [sort <[+|-] sort-field>] [SMA|WMA](number-of-datapoints, field) [AS alias] [[SMA|WMA](number-of-datapoints, field) [AS alias]]...`

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this command later enforce sorting either on an implicit timestamp field?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this command later enforce sorting either on an implicit timestamp field?

We could wait for more user feedback. It seems more than one command could be affected by an implicit timestamp field.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it could depend on customer feature request.


* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first.
* sort-field: mandatory when sorting is used. The field used to sort.
* number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero).
* field: mandatory. The name of the field the moving average should be calculated for.
* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline").

At the moment only the Simple Moving Average (SMA) type is supported.
Starting with version 3.1.0, two trendline algorithms are supported, namely Simple Moving Average (SMA) and Weighted Moving Average (WMA).

It is calculated like
Suppose:

f[i]: The value of field 'f' in the i-th data-point
n: The number of data-points in the moving window (period)
t: The current time index
* f[i]: The value of field 'f' in the i-th data-point
* n: The number of data-points in the moving window (period)
* t: The current time index

SMA is calculated like

SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t

Example 1: Calculate the moving average on one field.
WMA places more weight on recent values, whereas SMA weights all values in the window equally.

WMA(t) = (1/(1 + 2 + ... + n)) * (1 * f[t-n+1] + 2 * f[t-n+2] + ... + n * f[t])
= (2/(n * (n + 1))) * Σ((i - t + n) * f[i]), where i = t-n+1 to t


Example 1: Calculate the simple moving average on one field.
============================================================

The example shows how to calculate the moving average on one field.
The example shows how to calculate the simple moving average on one field.

PPL query::

Expand All @@ -52,10 +60,10 @@ PPL query::
+------+


Example 2: Calculate the moving average on multiple fields.
Example 2: Calculate the simple moving average on multiple fields.
==================================================================

The example shows how to calculate the moving average on multiple fields.
The example shows how to calculate the simple moving average on multiple fields.

PPL query::

Expand All @@ -70,10 +78,10 @@ PPL query::
| 15.5 | 30.5 |
+------+-----------+

Example 4: Calculate the moving average on one field without specifying an alias.
Example 3: Calculate the simple moving average on one field without specifying an alias.
========================================================================================

The example shows how to calculate the moving average on one field.
The example shows how to calculate the simple moving average on one field.

PPL query::

Expand All @@ -88,3 +96,40 @@ PPL query::
| 15.5 |
+--------------------------+

Example 4: Calculate the weighted moving average on one field.
=================================================================================

Version
-------
3.1.0

Configuration
-------------
The WMA algorithm requires Calcite to be enabled.

Enable Calcite:

>> curl -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings -d '{
"persistent" : {
"plugins.calcite.enabled" : true
}
}'

The example shows how to calculate the weighted moving average on one field.

PPL query::

PPL> source=accounts | trendline wma(2, account_number) | fields account_number_trendline;
fetched rows / total rows = 4/4
+--------------------------+
| account_number_trendline |
|--------------------------|
| null |
| 4.333333333333333 |
| 10.666666666666666 |
| 16.333333333333332 |
+--------------------------+

Limitation
==========
Starting with version 3.1.0, the ``trendline`` command requires all values in the specified ``field`` to be non-null. Any rows with null values present in the calculation field will be automatically excluded from the command's output.
4 changes: 2 additions & 2 deletions integ-test/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ integTest {

dependsOn ':opensearch-sql-plugin:bundlePlugin'
if(getOSFamilyType() != "windows") {
dependsOn startPrometheus
finalizedBy stopPrometheus
// dependsOn startPrometheus
// finalizedBy stopPrometheus
}

// enable calcite codegen in IT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,6 @@ public void init() throws Exception {
disallowCalciteFallback();
}

/**
 * Runs the inherited trendline push-down explain test through {@code withFallbackEnabled},
 * passing the tracking-issue URL. Any exception from the inherited test is rethrown as a
 * {@link RuntimeException} with the original as its cause.
 */
@Override
public void testTrendlinePushDownExplain() throws Exception {
  final String trackingIssue = "https://github.com/opensearch-project/sql/issues/3466";
  withFallbackEnabled(
      () -> {
        try {
          super.testTrendlinePushDownExplain();
        } catch (Exception failure) {
          throw new RuntimeException(failure);
        }
      },
      trackingIssue);
}

/**
 * Runs the inherited trendline-with-sort push-down explain test through
 * {@code withFallbackEnabled}, passing the tracking-issue URL. Any exception from the
 * inherited test is rethrown as a {@link RuntimeException} with the original as its cause.
 */
@Override
public void testTrendlineWithSortPushDownExplain() throws Exception {
  final String trackingIssue = "https://github.com/opensearch-project/sql/issues/3466";
  withFallbackEnabled(
      () -> {
        try {
          super.testTrendlineWithSortPushDownExplain();
        } catch (Exception failure) {
          throw new RuntimeException(failure);
        }
      },
      trackingIssue);
}

// Overrides the inherited test with an empty body and marks it ignored ("test only in v2"),
// so nothing is executed for this case in this class.
@Override
@Ignore("test only in v2")
public void testExplainModeUnsupportedInV2() throws IOException {}
Expand Down
Loading
Loading