Merged
1 change: 1 addition & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
Original file line number Diff line number Diff line change
@@ -60,6 +60,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns rows where b < '2024-09-10' or b > '2025-09-10'
- `source = table | where cidrmatch(ip, '192.169.1.0/24')`
- `source = table | where cidrmatch(ipv6, '2003:db8::/32')`
- `source = table | trendline sma(2, temperature) as temp_trend`

```sql
source = table | eval status_category =
```
3 changes: 2 additions & 1 deletion docs/ppl-lang/README.md
@@ -69,7 +69,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).
- [`subquery commands`](ppl-subquery-command.md)

- [`correlation commands`](ppl-correlation-command.md)


- [`trendline commands`](ppl-trendline-command.md)

* **Functions**

60 changes: 60 additions & 0 deletions docs/ppl-lang/ppl-trendline-command.md
@@ -0,0 +1,60 @@
## PPL trendline Command

**Description**
Use the ``trendline`` command to calculate moving averages of fields.


### Syntax
`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...`

* [+|-]: optional. The plus [+] stands for ascending order with NULL/MISSING first, and the minus [-] stands for descending order with NULL/MISSING last. **Default:** ascending order with NULL/MISSING first.
* sort-field: mandatory when sorting is used. The field to sort by.
* number-of-datapoints: mandatory. The number of datapoints over which the moving average is calculated (must be greater than zero).
* field: mandatory. The name of the field the moving average should be calculated for.
* alias: optional. The name of the resulting column containing the moving average.

At the moment, only the Simple Moving Average (SMA) type is supported.

It is calculated as follows:

    f[i]: The value of field 'f' in the i-th datapoint
    n:    The number of datapoints in the moving window (period)
    t:    The current time index

    SMA(t) = (1/n) * Σ f[i], where i = t-n+1 to t
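To make the formula concrete, here is a minimal Scala sketch (an illustration of the arithmetic, not the actual PPL implementation, which uses window functions) that computes the SMA over a window of n datapoints, emitting `None` for rows that do not yet have n datapoints — matching the NULL rows in the examples below.

```scala
object SmaSketch {
  // Simple moving average: for each index t, average the window
  // values(t - n + 1 .. t); None while fewer than n datapoints are available.
  def sma(values: Seq[Double], n: Int): Seq[Option[Double]] = {
    require(n > 0, "number-of-datapoints must be greater than zero")
    values.indices.map { t =>
      if (t + 1 < n) None // not enough datapoints yet -> NULL in PPL output
      else Some(values.slice(t - n + 1, t + 1).sum / n)
    }
  }
}
```

Applied to the temperatures of Example 1 with a window of two datapoints, `SmaSketch.sma(Seq(12, 12, 13, 14, 15), 2)` yields `None, Some(12.0), Some(12.5), Some(13.5), Some(14.5)`, reproducing the `temp_trend` column.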

### Example 1: Calculate simple moving average for a timeseries of temperatures

The example calculates the simple moving average over temperatures using two datapoints.

PPL query:

os> source=t | trendline sma(2, temperature) as temp_trend;
fetched rows / total rows = 5/5
+-----------+---------+--------------------+----------+
|temperature|device-id| timestamp|temp_trend|
+-----------+---------+--------------------+----------+
| 12| 1492|2023-04-06 17:07:...| NULL|
| 12| 1492|2023-04-06 17:07:...| 12.0|
| 13| 256|2023-04-06 17:07:...| 12.5|
| 14| 257|2023-04-06 17:07:...| 13.5|
| 15| 258|2023-04-06 17:07:...| 14.5|
+-----------+---------+--------------------+----------+

### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting

The example calculates two simple moving averages over temperatures, using two and three datapoints respectively, sorted in descending order by device-id.

PPL query:

os> source=t | trendline sort - device-id sma(2, temperature) as temp_trend_2 sma(3, temperature) as temp_trend_3;
fetched rows / total rows = 5/5
+-----------+---------+--------------------+------------+------------------+
|temperature|device-id| timestamp|temp_trend_2| temp_trend_3|
+-----------+---------+--------------------+------------+------------------+
| 15| 258|2023-04-06 17:07:...| NULL| NULL|
| 14| 257|2023-04-06 17:07:...| 14.5| NULL|
| 13| 256|2023-04-06 17:07:...| 13.5| 14.0|
| 12| 1492|2023-04-06 17:07:...| 12.5| 13.0|
| 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334|
+-----------+---------+--------------------+------------+------------------+
@@ -0,0 +1,247 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLTrendlineITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test trendline sma command without fields command and without alias") {
val frame = sql(s"""
| source = $testTable | sort - age | trendline sma(2, age)
| """.stripMargin)

assert(
frame.columns.sameElements(
Array("name", "age", "state", "country", "year", "month", "age_trendline")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, "California", "USA", 2023, 4, null),
Row("Hello", 30, "New York", "USA", 2023, 4, 50.0),
Row("John", 25, "Ontario", "Canada", 2023, 4, 27.5),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 22.5))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val ageField = UnresolvedAttribute("age")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), smaWindow)
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_trendline")())
val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command with fields command") {
val frame = sql(s"""
| source = $testTable | trendline sort - age sma(3, age) as age_sma | fields name, age, age_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "age", "age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jake", 70, null),
Row("Hello", 30, null),
Row("John", 25, 41.666666666666664),
Row("Jane", 20, 25))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val ageSmaField = UnresolvedAttribute("age_sma")
val sort = Sort(Seq(SortOrder(ageField, Descending)), global = true, table)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val smaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val caseWhen = CaseWhen(Seq((LessThan(countWindow, Literal(3)), Literal(null))), smaWindow)
val trendlineProjectList = Seq(UnresolvedStar(None), Alias(caseWhen, "age_sma")())
val expectedPlan =
Project(Seq(nameField, ageField, ageSmaField), Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test multiple trendline sma commands") {
val frame = sql(s"""
| source = $testTable | trendline sort + age sma(2, age) as two_points_sma sma(3, age) as three_points_sma | fields name, age, two_points_sma, three_points_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "age", "two_points_sma", "three_points_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 20, null, null),
Row("John", 25, 22.5, null),
Row("Hello", 30, 27.5, 25.0),
Row("Jake", 70, 50.0, 41.666666666666664))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val ageTwoPointsSmaField = UnresolvedAttribute("two_points_sma")
val ageThreePointsSmaField = UnresolvedAttribute("three_points_sma")
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, table)
val twoPointsCountWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val twoPointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val threePointsCountWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val threePointsSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(ageField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)))
val twoPointsCaseWhen = CaseWhen(
Seq((LessThan(twoPointsCountWindow, Literal(2)), Literal(null))),
twoPointsSmaWindow)
val threePointsCaseWhen = CaseWhen(
Seq((LessThan(threePointsCountWindow, Literal(3)), Literal(null))),
threePointsSmaWindow)
val trendlineProjectList = Seq(
UnresolvedStar(None),
Alias(twoPointsCaseWhen, "two_points_sma")(),
Alias(threePointsCaseWhen, "three_points_sma")())
val expectedPlan = Project(
Seq(nameField, ageField, ageTwoPointsSmaField, ageThreePointsSmaField),
Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command on evaluated column") {
val frame = sql(s"""
| source = $testTable | eval doubled_age = age * 2 | trendline sort + age sma(2, doubled_age) as doubled_age_sma | fields name, doubled_age, doubled_age_sma
| """.stripMargin)

assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_sma")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Jane", 40, null),
Row("John", 50, 45.0),
Row("Hello", 60, 55.0),
Row("Jake", 140, 100.0))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val nameField = UnresolvedAttribute("name")
val ageField = UnresolvedAttribute("age")
val doubledAgeField = UnresolvedAttribute("doubled_age")
val doubledAgeSmaField = UnresolvedAttribute("doubled_age_sma")
val evalProject = Project(
Seq(
UnresolvedStar(None),
Alias(
UnresolvedFunction("*", Seq(ageField, Literal(2)), isDistinct = false),
"doubled_age")()),
table)
val sort = Sort(Seq(SortOrder(ageField, Ascending)), global = true, evalProject)
val countWindow = new WindowExpression(
UnresolvedFunction("COUNT", Seq(Literal(1)), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val doubleAgeSmaWindow = WindowExpression(
UnresolvedFunction("AVG", Seq(doubledAgeField), isDistinct = false),
WindowSpecDefinition(Seq(), Seq(), SpecifiedWindowFrame(RowFrame, Literal(-1), CurrentRow)))
val caseWhen =
CaseWhen(Seq((LessThan(countWindow, Literal(2)), Literal(null))), doubleAgeSmaWindow)
val trendlineProjectList =
Seq(UnresolvedStar(None), Alias(caseWhen, "doubled_age_sma")())
val expectedPlan = Project(
Seq(nameField, doubledAgeField, doubledAgeSmaField),
Project(trendlineProjectList, sort))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test trendline sma command chaining") {
val frame = sql(s"""
| source = $testTable | eval age_1 = age, age_2 = age | trendline sort - age_1 sma(3, age_1) | trendline sort + age_2 sma(3, age_2)
| """.stripMargin)

assert(
frame.columns.sameElements(
Array(
"name",
"age",
"state",
"country",
"year",
"month",
"age_1",
"age_2",
"age_1_trendline",
"age_2_trendline")))
// Retrieve the results
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row("Hello", 30, "New York", "USA", 2023, 4, 30, 30, null, 25.0),
Row("Jake", 70, "California", "USA", 2023, 4, 70, 70, null, 41.666666666666664),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, 20, 25.0, null),
Row("John", 25, "Ontario", "Canada", 2023, 4, 25, 25, 41.666666666666664, null))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
}
}
4 changes: 4 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -38,6 +38,7 @@ AD: 'AD';
ML: 'ML';
FILLNULL: 'FILLNULL';
FLATTEN: 'FLATTEN';
TRENDLINE: 'TRENDLINE';

//Native JOIN KEYWORDS
JOIN: 'JOIN';
@@ -90,6 +91,9 @@ FIELDSUMMARY: 'FIELDSUMMARY';
INCLUDEFIELDS: 'INCLUDEFIELDS';
NULLS: 'NULLS';

//TRENDLINE KEYWORDS
SMA: 'SMA';

// ARGUMENT KEYWORDS
KEEPEMPTY: 'KEEPEMPTY';
CONSECUTIVE: 'CONSECUTIVE';
14 changes: 14 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -54,6 +54,7 @@ commands
| fillnullCommand
| fieldsummaryCommand
| flattenCommand
| trendlineCommand
;

commandName
@@ -84,6 +85,7 @@ commandName
| FILLNULL
| FIELDSUMMARY
| FLATTEN
| TRENDLINE
;

searchCommand
@@ -252,6 +254,17 @@ flattenCommand
: FLATTEN fieldExpression
;

trendlineCommand
: TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)*
;

trendlineClause
: trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)?
;

trendlineType
: SMA
;

kmeansCommand
: KMEANS (kmeansParameter)*
@@ -1131,4 +1144,5 @@ keywordsCanBeId
| ANTI
| BETWEEN
| CIDRMATCH
| trendlineType
;
@@ -111,6 +111,10 @@ public T visitLookup(Lookup node, C context) {
return visitChildren(node, context);
}

public T visitTrendline(Trendline node, C context) {
return visitChildren(node, context);
}

public T visitCorrelation(Correlation node, C context) {
return visitChildren(node, context);
}