From 4b8484fb1226b9588aa415c4bc479900560b9299 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 15:43:32 -0800 Subject: [PATCH 01/12] WMA implementation Signed-off-by: Andy Kwok --- .../ppl/FlintSparkPPLTrendlineITSuite.scala | 235 +++++++++++++++++- .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 2 +- .../opensearch/sql/ast/tree/Trendline.java | 2 +- .../function/BuiltinFunctionName.java | 1 + .../sql/ppl/CatalystQueryPlanVisitor.java | 2 +- .../sql/ppl/utils/TrendlineCatalystUtils.java | 180 ++++++++++++-- ...nTrendlineCommandTranslatorTestSuite.scala | 118 ++++++++- 8 files changed, 511 insertions(+), 30 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index bc4463537..fb22bc7c4 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -7,9 +7,11 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import 
org.opensearch.sql.ppl.utils.SortUtils class FlintSparkPPLTrendlineITSuite extends QueryTest @@ -244,4 +246,235 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) } + + test("test trendline wma command with sort field and without alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "age_trendline"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical +// // scalastyle:off +// println(logicalPlan.toString()) +// // scalastyle:on println + + val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val divisor = Literal(6) + val wmaExpression = Divide(dividend, divisor) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** + * Expected logical plan: + * 'Project [*] + * +- 'Project [*, ((( + * 
('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#185] + * +- 'Sort ['age ASC NULLS FIRST], true + * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test trendline wma command with sort field and with alias") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(3, age) as trendline_alias + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "trendline_alias"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + + val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val divisor = Literal(6) + val wmaExpression = Divide(dividend, divisor) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + 
Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + + // scalastyle:off + println(logicalPlan.toString()) + println(expectedPlan.toString()) + // scalastyle:on println + + /** + * 'Project [*] + * +- 'Project [*, ((( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] + * +- 'Sort ['age ASC NULLS FIRST], true + * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test multiple trendline wma commands") { + val frame = sql(s""" + | source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "age", "state", "country", "year", "month", "two_points_wma", "three_points_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 20, "Quebec", "Canada", 2023, 4, null, null), + Row("John", 25, "Ontario", "Canada", 2023, 4, 23.333333333333332, null), + Row("Hello", 30, "New York", "USA", 2023, 4, 28.333333333333332, 26.666666666666668), + Row("Jake", 70, "California", "USA", 2023, 4, 56.666666666666664, 49.166666666666664)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // TBC The logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical 
+ // scalastyle:off + println(logicalPlan.toString()) + // scalastyle:on println + + val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val threePointsExpression = Divide(dividend, Literal(6)) + + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** + * 'Project [*] + * +- 'Project [*, (( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#247, + * + * ((( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] + * +- 'Sort ['age ASC NULLS FIRST], true + * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test rendline wma command on evaluated column") { + val frame = sql(s""" + | source = $testTable | eval 
doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma + | """.stripMargin) + + assert( + frame.columns.sameElements( + Array("name", "doubled_age", "doubled_age_wma"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("Jane", 40, null), + Row("John", 50, 46.666666666666664), + Row("Hello", 60, 56.666666666666664), + Row("Jake", 140, 113.33333333333333)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // TBC The logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + val dividend = Add(getNthValueAggregation("doubled_age", "age", 1, -1), + getNthValueAggregation("doubled_age", "age", 2, -1)) + val wmaExpression = Divide(dividend, Literal(3)) + val trendlineProjectList = Seq(UnresolvedStar(None), + Alias(wmaExpression, "doubled_age_wma")()) + + val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) + + + val doubledAged = Alias(UnresolvedFunction(seq("*"), seq(UnresolvedAttribute("age"), Literal(2)), isDistinct = false) , "doubled_age")() + val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) + + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + doubleAgeProject) + + val expectedPlan = Project( + Seq(UnresolvedAttribute("name"),UnresolvedAttribute("doubled_age"),UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable )) + + + /** + * + 'Project ['name, 'doubled_age, 'doubled_age_wma] + +- 'Project [*, (( + ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + + ('nth_value('doubled_age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, 
-1, currentrow$())) * 2)) / 3) AS doubled_age_wma#288] + +- 'Sort ['age ASC NULLS FIRST], true + +- 'Project [*, '`*`('age, 2) AS doubled_age#287] + +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + + */ + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + + } + + private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow) + )), + Literal(lookBackPos)) + } + } diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 10b2e01b8..3ce8b6f1e 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -96,6 +96,7 @@ NULLS: 'NULLS'; //TRENDLINE KEYWORDS SMA: 'SMA'; +WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index e44964c72..24cdc21b0 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -271,7 +271,7 @@ trendlineClause ; trendlineType - : SMA + : (SMA | WMA) ; kmeansCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 9fa1ae81d..d08e89e3b 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -62,6 +62,6 
@@ public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dat } public enum TrendlineType { - SMA + SMA, WMA } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f039bf47f..e232c3668 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -293,6 +293,7 @@ public enum BuiltinFunctionName { WILDCARDQUERY(FunctionName.of("wildcardquery")), WILDCARD_QUERY(FunctionName.of("wildcard_query")), + NTH_VALUE(FunctionName.of("nth_value")), COALESCE(FunctionName.of("coalesce")); private FunctionName name; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 00a7905f0..debd37376 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -245,7 +245,7 @@ public LogicalPlan visitTrendline(Trendline node, CatalystPlanContext context) { trendlineProjectExpressions.add(UnresolvedStar$.MODULE$.apply(Option.empty())); } - trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), context)); + trendlineProjectExpressions.addAll(TrendlineCatalystUtils.visitTrendlineComputations(expressionAnalyzer, node.getComputations(), node.getSortByField(), context)); return context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(seq(trendlineProjectExpressions), p)); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 67603ccc7..05ac40988 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -6,8 +6,7 @@ package org.opensearch.sql.ppl.utils; import org.apache.spark.sql.catalyst.expressions.*; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -16,20 +15,26 @@ import scala.Option; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; public interface TrendlineCatalystUtils { - static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, CatalystPlanContext context) { + + static List visitTrendlineComputations(CatalystExpressionVisitor expressionVisitor, List computations, Optional sortField, CatalystPlanContext context) { return computations.stream() - .map(computation -> visitTrendlineComputation(expressionVisitor, computation, context)) + .map(computation -> visitTrendlineComputation(expressionVisitor, computation, sortField, context)) .collect(Collectors.toList()); } - static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, CatalystPlanContext context) { + + static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expressionVisitor, Trendline.TrendlineComputation node, Optional sortField, CatalystPlanContext context) { + //window lower boundary 
expressionVisitor.visitLiteral(new Literal(Math.negateExact(node.getNumberOfDataPoints() - 1), DataType.INTEGER), context); Expression windowLowerBoundary = context.popNamedParseExpressions().get(); @@ -40,26 +45,28 @@ static NamedExpression visitTrendlineComputation(CatalystExpressionVisitor expre seq(), new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); - if (node.getComputationType() == Trendline.TrendlineType.SMA) { - //calculate avg value of the data field - expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); - Expression avgFunction = context.popNamedParseExpressions().get(); - - //sma window - WindowExpression sma = new WindowExpression( - avgFunction, - windowDefinition); - - CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); - - return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(smaOrNull, - node.getAlias(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList())); - } else { - throw new IllegalArgumentException(node.getComputationType()+" is not supported"); + switch (node.getComputationType()) { + case SMA: + //calculate avg value of the data field + expressionVisitor.visitAggregateFunction(new AggregateFunction(BuiltinFunctionName.AVG.name(), node.getDataField()), context); + Expression avgFunction = context.popNamedParseExpressions().get(); + + //sma window + WindowExpression sma = new WindowExpression( + avgFunction, + windowDefinition); + + CaseWhen smaOrNull = trendlineOrNullWhenThereAreTooFewDataPoints(expressionVisitor, sma, node, context); + + return getAlias(node.getAlias(), smaOrNull); + case WMA: + if (sortField.isPresent()) { + return getWMAComputationExpression(expressionVisitor, node, sortField.get(), context); + } else { + throw new IllegalArgumentException(node.getComputationType()+" requires a 
sort field for computation"); + } + default: + throw new IllegalArgumentException(node.getComputationType()+" is not supported"); } } @@ -84,4 +91,127 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr ); return new CaseWhen(seq(nullWhenNumberOfDataPointsLessThenRequired), Option.apply(trendlineWindow)); } + + /** + * Produces a Spark logical plan expression for the given trendline command arguments; below is a sample logical plan + * with configuration [dataField=salary, sortField=age, dataPoints=3] + * -- +- 'Project [ + * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) + * -- AS WMA#702] + * + * @param visitor Visitor instance to process any UnresolvedExpression. + * @param node Trendline command's arguments. + * @param sortField Field used for window aggregation. + * @param context Context instance used to retrieve the Expression in resolved form. + * @return a NamedExpression instance which will calculate WMA with the provided arguments. 
+ */ + private static NamedExpression getWMAComputationExpression(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + Field sortField, + CatalystPlanContext context) { + + //window lower boundary + Expression windowLowerBoundary = getIntExpression(visitor, context, + Math.negateExact(node.getNumberOfDataPoints() - 1)); + //window definition + visitor.analyze(sortField, context); + Expression sortDefinition = context.popNamedParseExpressions().get(); + WindowSpecDefinition windowDefinition = getCommonWindowDefinition( + sortDefinition, + SortUtils.isSortedAscending(sortField), + windowLowerBoundary); + // Divisor + Expression divider = getIntExpression(visitor, context, + (node.getNumberOfDataPoints() * (node.getNumberOfDataPoints()+1) / 2)); + // Aggregation + Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition, + node.getNumberOfDataPoints()) + .stream() + .reduce(Add::new) + .orElse(null); + + return getAlias(node.getAlias(), new Divide(WMAExpression, divider)); + } + + /** + * Helper method to produce an Alias Expression with provide value and name. + * @param name The name for the Alias. + * @param expression The expression which will be evaluated. + * @return A Alias instance with logical plan representation of `expression AS name`. + */ + private static NamedExpression getAlias(String name, Expression expression) { + return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression, + name, + NamedExpression.newExprId(), + seq(Collections.emptyList()), + Option.empty(), + seq(Collections.emptyList())); + } + + /** + * Helper method to retrieve an Int in expression form for logical plan composition purpose. + * @param expressionVisitor Visitor instance to process the incoming object. + * @param context Context instance to retrieve the Expression instance. + * @param i Target value for the expression. + * @return An expression object which contain integer value i. 
+ */ + static Expression getIntExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { + expressionVisitor.visitLiteral(new Literal(i, + DataType.INTEGER), context); + return context.popNamedParseExpressions().get(); + } + + + /** + * Helper method to retrieve a WindowSpecDefinition with the provided sorting condition. + * `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())` + * @param sortField The field being used for the sorting operation. + * @param ascending Whether the sort order is ascending. + * @param windowLowerBoundary The Integer expression used as the lower bound of the row frame (the frame ends at the current row). + * @return A WindowSpecDefinition instance which will be used to compose the WMA calculation. + */ + static WindowSpecDefinition getCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) { + return new WindowSpecDefinition( + seq(), + seq(SortUtils.sortOrder(sortField, ascending)), + new SpecifiedWindowFrame(RowFrame$.MODULE$, windowLowerBoundary, CurrentRow$.MODULE$)); + } + + /** + * Produces the list of Expressions responsible for returning the appropriate lookbehind / lookahead values for the WMA calculation; a sample logical plan fragment is listed below. + * (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * + * @param visitor Visitor instance to resolve Expression. + * @param node Trendline command instruction. + * @param context Context instance to retrieve the resolved expression. + * @param windowDefinition The windowDefinition for the individual datapoint lookbehind / lookahead. + * @param dataPoints Number of data-points for WMA calculation; this always equals the number of Expressions generated. + * @return List instance which contains the expressions for the individual data points of the WMA calculation. 
+ */ + private static List getNthValueAggregations(CatalystExpressionVisitor visitor, + Trendline.TrendlineComputation node, + CatalystPlanContext context, + WindowSpecDefinition windowDefinition, + int dataPoints) { + + List expressions = new ArrayList<>(); + for (int i = 1; i <= dataPoints; i++) { + // Get the offset parameter + Expression offSetExpression = getIntExpression(visitor, context, i); + + // Composite the nth_value expression. + Function func = new Function(BuiltinFunctionName.NTH_VALUE.name(), + List.of(node.getDataField(), new Literal(i, DataType.INTEGER))); + + visitor.visitFunction(func, context); + Expression nthValueExp = context.popNamedParseExpressions().get(); + + expressions.add(new Multiply( + new WindowExpression(nthValueExp, windowDefinition), offSetExpression)); + } + return expressions; + } + } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index d22750ee0..00dbbc574 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -7,11 +7,13 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, CaseWhen, CurrentRow, Descending, 
LessThan, Literal, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} @@ -132,4 +134,118 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite Project(trendlineProjectList, sort)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } + + test("wma - with sort") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age)"), context) + + val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val divisor = Literal(6) + val wmaExpression = Divide(dividend, divisor) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** + * Expected logical plan: + * 'Project [*] + * !+- 'Project [*, ((( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#0] + * ! +- 'Sort ['age ASC NULLS FIRST], true + * ! 
+- 'UnresolvedRelation [relation], [], false + */ + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + } + + test("wma - with sort and alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), context) + + val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val divisor = Literal(6) + val wmaExpression = Divide(dividend, divisor) + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: + * 'Project [*] + * !+- 'Project [*, ((( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] + * ! +- 'Sort ['age ASC NULLS FIRST], true + * ! 
+- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + test("wma - multiple trendline commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), context) + + val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), + getNthValueAggregation("age", "age", 2, -1)) + val twoPointsExpression = Divide(dividendTwo, Literal(3)) + + val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) + val divisor = Literal(6) + val threePointsExpression = Divide(dividend, divisor) + + val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) + + val sortedTable = Sort( + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + UnresolvedRelation(Seq("relation"))) + + /** + * Expected logical plan: + * 'Project [*] + * +- 'Project [*, (( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * + * ((( + * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] + * +- 'Sort ['age ASC NULLS 
FIRST], true + * +- 'UnresolvedRelation [relation], [], false + */ + val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + comparePlans(logPlan, expectedPlan, checkAnalysis = false) + + } + + + + private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { + Multiply( + WindowExpression( + UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false), + WindowSpecDefinition( + Seq(), + seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow) + )), + Literal(lookBackPos)) + } + + } From 647f9711ad29e63ecf8a5565e658832a30080545 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 15:55:30 -0800 Subject: [PATCH 02/12] Update test cases Signed-off-by: Andy Kwok --- .../ppl/FlintSparkPPLTrendlineITSuite.scala | 40 +++++-------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index fb22bc7c4..7c10a6dd0 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -252,6 +252,7 @@ class FlintSparkPPLTrendlineITSuite | source = $testTable | trendline sort + age wma(3, age) | """.stripMargin) + // Compare the headers assert( frame.columns.sameElements( Array("name", "age", "state", "country", "year", "month", "age_trendline"))) @@ -268,16 +269,12 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + // Compare the logical 
plans val logicalPlan: LogicalPlan = frame.queryExecution.logical -// // scalastyle:off -// println(logicalPlan.toString()) -// // scalastyle:on println - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val wmaExpression = Divide(dividend, divisor) + val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( @@ -301,6 +298,7 @@ class FlintSparkPPLTrendlineITSuite | source = $testTable | trendline sort + age wma(3, age) as trendline_alias | """.stripMargin) + // Compare the headers assert( frame.columns.sameElements( Array("name", "age", "state", "country", "year", "month", "trendline_alias"))) @@ -317,25 +315,18 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val wmaExpression = Divide(dividend, divisor) + val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) - // scalastyle:off - println(logicalPlan.toString()) - println(expectedPlan.toString()) - // scalastyle:on 
println - /** * 'Project [*] * +- 'Project [*, ((( @@ -353,6 +344,7 @@ class FlintSparkPPLTrendlineITSuite | source = $testTable | trendline sort + age wma(2, age) as two_points_wma wma(3, age) as three_points_wma | """.stripMargin) + // Compare the headers assert( frame.columns.sameElements( Array("name", "age", "state", "country", "year", "month", "two_points_wma", "three_points_wma"))) @@ -369,11 +361,8 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // TBC The logical plan + // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - // scalastyle:off - println(logicalPlan.toString()) - // scalastyle:on println val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), getNthValueAggregation("age", "age", 2, -1)) @@ -405,11 +394,12 @@ class FlintSparkPPLTrendlineITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - test("test rendline wma command on evaluated column") { + test("test trendline wma command on evaluated column") { val frame = sql(s""" | source = $testTable | eval doubled_age = age * 2 | trendline sort + age wma(2, doubled_age) as doubled_age_wma | fields name, doubled_age, doubled_age_wma | """.stripMargin) + // Compare the headers assert( frame.columns.sameElements( Array("name", "doubled_age", "doubled_age_wma"))) @@ -426,30 +416,22 @@ class FlintSparkPPLTrendlineITSuite implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) - // TBC The logical plan + // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(getNthValueAggregation("doubled_age", "age", 1, -1), getNthValueAggregation("doubled_age", "age", 2, -1)) val wmaExpression = Divide(dividend, Literal(3)) val trendlineProjectList = 
Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) - val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) - - val doubledAged = Alias(UnresolvedFunction(seq("*"), seq(UnresolvedAttribute("age"), Literal(2)), isDistinct = false) , "doubled_age")() val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) - val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) - val expectedPlan = Project( Seq(UnresolvedAttribute("name"),UnresolvedAttribute("doubled_age"),UnresolvedAttribute("doubled_age_wma")), Project(trendlineProjectList, sortedTable )) - - /** * 'Project ['name, 'doubled_age, 'doubled_age_wma] From b0d09c7af5278b60e1a7d179c4dbdf3726ecf481 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 15:58:15 -0800 Subject: [PATCH 03/12] Update tests Signed-off-by: Andy Kwok --- ...nTrendlineCommandTranslatorTestSuite.scala | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index 00dbbc574..d49672ce4 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -135,7 +135,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite comparePlans(logPlan, expectedPlan, checkAnalysis = false) } - test("wma - with sort") { + test("WMA - with sort") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age)"), context) @@ -143,8 +143,7 @@ class 
PPLLogicalPlanTrendlineCommandTranslatorTestSuite val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val wmaExpression = Divide(dividend, divisor) + val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) @@ -162,7 +161,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite comparePlans(logPlan, expectedPlan, checkAnalysis = false) } - test("wma - with sort and alias") { + test("WMA - with sort and alias") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), context) @@ -170,10 +169,8 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val wmaExpression = Divide(dividend, divisor) + val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) - val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) @@ -193,7 +190,7 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite } - test("wma - multiple trendline commands") { + test("WMA - multiple trendline commands") { val context = new CatalystPlanContext val logPlan = planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), context) @@ -205,15 +202,11 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite val 
dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) - val divisor = Literal(6) - val threePointsExpression = Divide(dividend, divisor) - + val threePointsExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) - val sortedTable = Sort( Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) - /** * Expected logical plan: * 'Project [*] @@ -233,8 +226,6 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite } - - private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { Multiply( WindowExpression( From 99ed2b5ff8628dfb4193bb6c377e23bbb7209363 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 16:36:31 -0800 Subject: [PATCH 04/12] Refactor code Signed-off-by: Andy Kwok --- .../sql/ppl/utils/TrendlineCatalystUtils.java | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 05ac40988..307041bd4 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -111,35 +111,34 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis Trendline.TrendlineComputation node, Field sortField, CatalystPlanContext context) { - + int dataPoints = node.getNumberOfDataPoints(); //window lower boundary Expression windowLowerBoundary = getIntExpression(visitor, context, - Math.negateExact(node.getNumberOfDataPoints() - 1)); + 
Math.negateExact(dataPoints - 1)); //window definition visitor.analyze(sortField, context); Expression sortDefinition = context.popNamedParseExpressions().get(); - WindowSpecDefinition windowDefinition = getCommonWindowDefinition( + WindowSpecDefinition windowDefinition = getWmaCommonWindowDefinition( sortDefinition, SortUtils.isSortedAscending(sortField), windowLowerBoundary); // Divisor - Expression divider = getIntExpression(visitor, context, - (node.getNumberOfDataPoints() * (node.getNumberOfDataPoints()+1) / 2)); + Expression divisor = getIntExpression(visitor, context, + (dataPoints * (dataPoints + 1) / 2)); // Aggregation - Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition, - node.getNumberOfDataPoints()) + Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) .stream() .reduce(Add::new) .orElse(null); - return getAlias(node.getAlias(), new Divide(WMAExpression, divider)); + return getAlias(node.getAlias(), new Divide(WMAExpression, divisor)); } /** * Helper method to produce an Alias Expression with provide value and name. * @param name The name for the Alias. * @param expression The expression which will be evaluated. - * @return A Alias instance with logical plan representation of `expression AS name`. + * @return An Alias instance with logical plan representation of `expression AS name`. */ private static NamedExpression getAlias(String name, Expression expression) { return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(expression, @@ -151,7 +150,7 @@ private static NamedExpression getAlias(String name, Expression expression) { } /** - * Helper method to retrieve an Int in expression form for logical plan composition purpose. + * Helper method to retrieve an Int expression instance for logical plan composition purpose. * @param expressionVisitor Visitor instance to process the incoming object. 
* @param context Context instance to retrieve the Expression instance. * @param i Target value for the expression. @@ -167,12 +166,13 @@ static Expression getIntExpression(CatalystExpressionVisitor expressionVisitor, /** * Helper method to retrieve a WindowSpecDefinition with provided sorting condition. * `windowspecdefinition('sortField ascending NULLS FIRST, specifiedwindowframe(RowFrame, windowLowerBoundary, currentrow$())` + * * @param sortField The field being used for the sorting operation. * @param ascending The boolean instance for the sorting order. * @param windowLowerBoundary The Integer expression instance which specify the even lookbehind / lookahead. * @return A WindowSpecDefinition instance which will be used to composite the WMA calculation. */ - static WindowSpecDefinition getCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) { + static WindowSpecDefinition getWmaCommonWindowDefinition(Expression sortField, boolean ascending, Expression windowLowerBoundary) { return new WindowSpecDefinition( seq(), seq(SortUtils.sortOrder(sortField, ascending)), @@ -180,7 +180,7 @@ static WindowSpecDefinition getCommonWindowDefinition(Expression sortField, bool } /** - * To produce a list of Expression with responsible to return appropriate lookbehind / lookahead value for WMA calculation, sample logical plan listed below. + * To produce a list of Expressions responsible to return appropriate lookbehind / lookahead value for WMA calculation, sample logical plan listed below. * (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + * * @param visitor Visitor instance to resolve Expression. 
@@ -195,7 +195,6 @@ private static List getNthValueAggregations(CatalystExpressionVisito CatalystPlanContext context, WindowSpecDefinition windowDefinition, int dataPoints) { - List expressions = new ArrayList<>(); for (int i = 1; i <= dataPoints; i++) { // Get the offset parameter From 74f001f650b862b472c8890803fe5ef457215ecb Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 17:30:51 -0800 Subject: [PATCH 05/12] Addres comments Signed-off-by: Andy Kwok --- .../src/main/antlr4/OpenSearchPPLParser.g4 | 3 ++- .../sql/ppl/utils/TrendlineCatalystUtils.java | 27 ++++++++++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 24cdc21b0..7290a6d10 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -271,7 +271,8 @@ trendlineClause ; trendlineType - : (SMA | WMA) + : SMA + | WMA ; kmeansCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 307041bd4..513561bfa 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -96,10 +96,19 @@ private static CaseWhen trendlineOrNullWhenThereAreTooFewDataPoints(CatalystExpr * Responsible to produce a Spark Logical Plan with given TrendLine command arguments, below is the sample logical plan * with configuration [dataField=salary, sortField=age, dataPoints=3] * -- +- 'Project [ - * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - * -- ('nth_value('salary, 2) 
windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) + - * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 3)) / 6) + * -- (((('nth_value('salary, 1) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + + * -- ('nth_value('salary, 2) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + + * -- ('nth_value('salary, 3) windowspecdefinition(Field(field=age, fieldArgs=[]) ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) * -- AS WMA#702] + * . + * And the corresponded SQL query: + * . + * SELECT name, salary, + * ( nth_value(salary, 1) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *1 + + * nth_value(salary, 2) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *2 + + * nth_value(salary, 3) OVER (ORDER BY age ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) *3 )/6 AS WMA + * FROM employees + * ORDER BY age; * * @param visitor Visitor instance to process any UnresolvedExpression. * @param node Trendline command's arguments. 
@@ -113,7 +122,7 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis CatalystPlanContext context) { int dataPoints = node.getNumberOfDataPoints(); //window lower boundary - Expression windowLowerBoundary = getIntExpression(visitor, context, + Expression windowLowerBoundary = parseIntToExpression(visitor, context, Math.negateExact(dataPoints - 1)); //window definition visitor.analyze(sortField, context); @@ -123,15 +132,15 @@ private static NamedExpression getWMAComputationExpression(CatalystExpressionVis SortUtils.isSortedAscending(sortField), windowLowerBoundary); // Divisor - Expression divisor = getIntExpression(visitor, context, + Expression divisor = parseIntToExpression(visitor, context, (dataPoints * (dataPoints + 1) / 2)); // Aggregation - Expression WMAExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) + Expression wmaExpression = getNthValueAggregations(visitor, node, context, windowDefinition, dataPoints) .stream() .reduce(Add::new) .orElse(null); - return getAlias(node.getAlias(), new Divide(WMAExpression, divisor)); + return getAlias(node.getAlias(), new Divide(wmaExpression, divisor)); } /** @@ -156,7 +165,7 @@ private static NamedExpression getAlias(String name, Expression expression) { * @param i Target value for the expression. * @return An expression object which contain integer value i. 
*/ - static Expression getIntExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { + static Expression parseIntToExpression(CatalystExpressionVisitor expressionVisitor, CatalystPlanContext context, int i) { expressionVisitor.visitLiteral(new Literal(i, DataType.INTEGER), context); return context.popNamedParseExpressions().get(); @@ -198,7 +207,7 @@ private static List getNthValueAggregations(CatalystExpressionVisito List expressions = new ArrayList<>(); for (int i = 1; i <= dataPoints; i++) { // Get the offset parameter - Expression offSetExpression = getIntExpression(visitor, context, i); + Expression offSetExpression = parseIntToExpression(visitor, context, i); // Composite the nth_value expression. Function func = new Function(BuiltinFunctionName.NTH_VALUE.name(), From 039323368fc8a65d94b932c321bfaeccd6a53194 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 18:15:35 -0800 Subject: [PATCH 06/12] Update doc Signed-off-by: Andy Kwok --- docs/ppl-lang/ppl-trendline-command.md | 64 +++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/docs/ppl-lang/ppl-trendline-command.md b/docs/ppl-lang/ppl-trendline-command.md index 393a9dd59..b466e2e8f 100644 --- a/docs/ppl-lang/ppl-trendline-command.md +++ b/docs/ppl-lang/ppl-trendline-command.md @@ -3,8 +3,7 @@ **Description** Using ``trendline`` command to calculate moving averages of fields. - -### Syntax +### Syntax - SMA (Simple Moving Average) `TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. @@ -13,8 +12,6 @@ Using ``trendline`` command to calculate moving averages of fields. * field: mandatory. 
the name of the field the moving average should be calculated for. * alias: optional. the name of the resulting column containing the moving average. -And the moment only the Simple Moving Average (SMA) type is supported. - It is calculated like f[i]: The value of field 'f' in the i-th data-point @@ -23,7 +20,7 @@ It is calculated like SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t -### Example 1: Calculate simple moving average for a timeseries of temperatures +#### Example 1: Calculate simple moving average for a timeseries of temperatures The example calculates the simple moving average over temperatures using two datapoints. @@ -41,7 +38,7 @@ PPL query: | 15| 258|2023-04-06 17:07:...| 14.5| +-----------+---------+--------------------+----------+ -### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting +#### Example 2: Calculate simple moving averages for a timeseries of temperatures with sorting The example calculates two simple moving average over temperatures using two and three datapoints sorted descending by device-id. @@ -58,3 +55,58 @@ PPL query: | 12| 1492|2023-04-06 17:07:...| 12.5| 13.0| | 12| 1492|2023-04-06 17:07:...| 12.0|12.333333333333334| +-----------+---------+--------------------+------------+------------------+ + + +### Syntax - WMA (Weighted Moving Average) +`TRENDLINE sort <[+|-] sort-field> WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory. this field specifies the ordering of data points when calculating the nth_value aggregation. +* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving average should be calculated for. 
+* alias: optional. the name of the resulting column containing the moving average. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data point + n: The number of data points in the moving window (period) + t: The current time index + w[i]: The weight assigned to the i-th data point, typically increasing for more recent points + + WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] ) + +#### Example 1: Calculate weighted moving average for a timeseries of temperatures + +The example calculates the weighted moving average over temperatures using two datapoints. + +PPL query: + + os> source=t | trendline sort timestamp wma(2, temperature) as temp_trend; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+----------+ + |temperature|device-id| timestamp|temp_trend| + +-----------+---------+--------------------+----------+ + | 12| 1492|2023-04-06 17:07:...| NULL| + | 12| 1492|2023-04-06 17:07:...| 12.0| + | 13| 256|2023-04-06 17:07:...| 12.6| + | 14| 257|2023-04-06 17:07:...| 13.6| + | 15| 258|2023-04-06 17:07:...| 14.6| + +-----------+---------+--------------------+----------+ + +#### Example 2: Calculate weighted moving averages for a timeseries of temperatures with sorting + +The example calculates two weighted moving averages over temperatures using two and three datapoints sorted descending by device-id. 
+ +PPL query: + + os> source=t | trendline sort - device-id wma(2, temperature) as temp_trend_2 wma(3, temperature) as temp_trend_3; + fetched rows / total rows = 5/5 + +-----------+---------+--------------------+------------+------------------+ + |temperature|device-id| timestamp|temp_trend_2| temp_trend_3| + +-----------+---------+--------------------+------------+------------------+ + | 15| 258|2023-04-06 17:07:...| NULL| NULL| + | 14| 257|2023-04-06 17:07:...| 14.3| NULL| + | 13| 256|2023-04-06 17:07:...| 13.3| 13.6| + | 12| 1492|2023-04-06 17:07:...| 12.3| 12.6| + | 12| 1492|2023-04-06 17:07:...| 12.0| 12.16| + +-----------+---------+--------------------+------------+------------------+ From 9c7a9f95b7b4e1260edd082e04950dbf58a9b318 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 6 Nov 2024 19:21:30 -0800 Subject: [PATCH 07/12] Update example Signed-off-by: Andy Kwok --- docs/ppl-lang/PPL-Example-Commands.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 851531b5b..7766c3b50 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -65,6 +65,7 @@ _- **Limitation: new field added by eval command with a function cannot be dropp - `source = table | where cidrmatch(ip, '192.169.1.0/24')` - `source = table | where cidrmatch(ipv6, '2003:db8::/32')` - `source = table | trendline sma(2, temperature) as temp_trend` +- `source = table | trendline sort timestamp wma(2, temperature) as temp_trend` #### **IP related queries** [See additional command details](functions/ppl-ip.md) From 67660677c511c8e619e3c05dd729cd5f482c3794 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 7 Nov 2024 09:56:52 -0800 Subject: [PATCH 08/12] Update readme Signed-off-by: Andy Kwok --- DEVELOPER_GUIDE.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index bb8f697ec..23373fb84 100644 --- a/DEVELOPER_GUIDE.md 
+++ b/DEVELOPER_GUIDE.md @@ -11,6 +11,11 @@ To execute the unit tests, run the following command: ``` sbt test ``` +To run a specific unit test in SBT, use the testOnly command with the full path of the test class: +``` +sbt test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite +``` + ## Integration Test The integration test is defined in the `integration` directory of the project. The integration tests will automatically trigger unit tests and will only run if all unit tests pass. If you want to run the integration test for the project, you can do so by running the following command: @@ -23,6 +28,13 @@ If you get integration test failures with error message "Previous attempts to fi 3. Run `sudo ln -s $HOME/.docker/desktop/docker.sock /var/run/docker.sock` or `sudo ln -s $HOME/.docker/run/docker.sock /var/run/docker.sock` 4. If you use Docker Desktop, as an alternative of `3`, check mark the "Allow the default Docker socket to be used (requires password)" in advanced settings of Docker Desktop. +Running only a selected set of integration test suites is possible with the following command: +``` +sbt "project integtest" it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite +``` +This command runs only the specified test suite within the integtest submodule. + + ### AWS Integration Test The `aws-integration` folder contains tests for cloud server providers. For instance, test against AWS OpenSearch domain, configure the following settings. The client will use the default credential provider to access the AWS OpenSearch domain. 
``` From 54fecc568a843889d9cce7636e49ebbdd2b7c4e6 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Tue, 12 Nov 2024 12:43:25 -0800 Subject: [PATCH 09/12] Update scalafmt Signed-off-by: Andy Kwok --- .../ppl/FlintSparkPPLTrendlineITSuite.scala | 178 +++++++++++------- ...nTrendlineCommandTranslatorTestSuite.scala | 132 ++++++++----- 2 files changed, 191 insertions(+), 119 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index 7c10a6dd0..589cad33b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -5,13 +5,14 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.opensearch.sql.ppl.utils.SortUtils + import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest -import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq -import org.opensearch.sql.ppl.utils.SortUtils class FlintSparkPPLTrendlineITSuite extends QueryTest @@ -271,24 +272,29 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + 
getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** - * Expected logical plan: - * 'Project [*] - * +- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#185] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * Expected logical plan: 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#185] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false */ comparePlans(logicalPlan, 
expectedPlan, checkAnalysis = false) } @@ -317,24 +323,30 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "trendline_alias")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) /** - * 'Project [*] - * +- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * 'Project [*] +- 'Project [*, ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC 
NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS trendline_alias#185] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -347,7 +359,15 @@ class FlintSparkPPLTrendlineITSuite // Compare the headers assert( frame.columns.sameElements( - Array("name", "age", "state", "country", "year", "month", "two_points_wma", "three_points_wma"))) + Array( + "name", + "age", + "state", + "country", + "year", + "month", + "two_points_wma", + "three_points_wma"))) // Retrieve the results val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = @@ -364,32 +384,43 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), getNthValueAggregation("age", "age", 2, -1)) val twoPointsExpression = Divide(dividendTwo, Literal(3)) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val threePointsExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) val sortedTable = 
Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, unresolvedRelation) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + unresolvedRelation) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + /** - * 'Project [*] - * +- 'Project [*, (( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#247, + * 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS + * FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS two_points_wma#247, * - * ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#248] 
+- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [spark_catalog, default, + * flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -400,9 +431,7 @@ class FlintSparkPPLTrendlineITSuite | """.stripMargin) // Compare the headers - assert( - frame.columns.sameElements( - Array("name", "doubled_age", "doubled_age_wma"))) + assert(frame.columns.sameElements(Array("name", "doubled_age", "doubled_age_wma"))) // Retrieve the results val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = @@ -418,44 +447,57 @@ class FlintSparkPPLTrendlineITSuite // Compare the logical plans val logicalPlan: LogicalPlan = frame.queryExecution.logical - val dividend = Add(getNthValueAggregation("doubled_age", "age", 1, -1), + val dividend = Add( + getNthValueAggregation("doubled_age", "age", 1, -1), getNthValueAggregation("doubled_age", "age", 2, -1)) val wmaExpression = Divide(dividend, Literal(3)) - val trendlineProjectList = Seq(UnresolvedStar(None), - Alias(wmaExpression, "doubled_age_wma")()) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "doubled_age_wma")()) val unresolvedRelation = UnresolvedRelation(testTable.split("\\.").toSeq) - val doubledAged = Alias(UnresolvedFunction(seq("*"), seq(UnresolvedAttribute("age"), Literal(2)), isDistinct = false) , "doubled_age")() + val doubledAged = Alias( + UnresolvedFunction( + seq("*"), + seq(UnresolvedAttribute("age"), Literal(2)), + isDistinct = false), + "doubled_age")() val doubleAgeProject = Project(seq(UnresolvedStar(None), doubledAged), unresolvedRelation) - val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, - doubleAgeProject) + val sortedTable = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, doubleAgeProject) val expectedPlan = Project( - Seq(UnresolvedAttribute("name"),UnresolvedAttribute("doubled_age"),UnresolvedAttribute("doubled_age_wma")), - 
Project(trendlineProjectList, sortedTable )) - /** - * - 'Project ['name, 'doubled_age, 'doubled_age_wma] - +- 'Project [*, (( - ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - ('nth_value('doubled_age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS doubled_age_wma#288] - +- 'Sort ['age ASC NULLS FIRST], true - +- 'Project [*, '`*`('age, 2) AS doubled_age#287] - +- 'UnresolvedRelation [spark_catalog, default, flint_ppl_test], [], false + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("doubled_age"), + UnresolvedAttribute("doubled_age_wma")), + Project(trendlineProjectList, sortedTable)) + /** + * 'Project ['name, 'doubled_age, 'doubled_age_wma] +- 'Project [*, (( + * ('nth_value('doubled_age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + ('nth_value('doubled_age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 2)) / 3) AS doubled_age_wma#288] +- 'Sort ['age ASC NULLS FIRST], true +- + * 'Project [*, '`*`('age, 2) AS doubled_age#287] +- 'UnresolvedRelation [spark_catalog, + * default, flint_ppl_test], [], false */ comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { Multiply( WindowExpression( - UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false), + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), WindowSpecDefinition( Seq(), seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), - 
SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow) - )), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), Literal(lookBackPos)) } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index d49672ce4..baf472a08 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -138,25 +138,32 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite test("WMA - with sort") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age)"), context) + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age)"), + context) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), - getNthValueAggregation("age", "age", 3, -2)) + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), + getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "age_trendline")()) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, UnresolvedRelation(Seq("relation"))) - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, + UnresolvedRelation(Seq("relation"))) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), 
Project(trendlineProjectList, sortedTable)) + /** - * Expected logical plan: - * 'Project [*] - * !+- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS age_trendline#0] - * ! +- 'Sort ['age ASC NULLS FIRST], true - * ! +- 'UnresolvedRelation [relation], [], false + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS age_trendline#0] ! +- 'Sort ['age ASC NULLS FIRST], true ! 
+- + * 'UnresolvedRelation [relation], [], false */ comparePlans(logPlan, expectedPlan, checkAnalysis = false) } @@ -164,28 +171,34 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite test("WMA - with sort and alias") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), context) + planTransformer.visit( + plan(pplParser, "source=relation | trendline sort age wma(3, age) as TEST_CUSTOM_COLUMN"), + context) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val wmaExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) + val trendlineProjectList = + Seq(UnresolvedStar(None), Alias(wmaExpression, "TEST_CUSTOM_COLUMN")()) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, UnresolvedRelation(Seq("relation"))) /** - * Expected logical plan: - * 'Project [*] - * !+- 'Project [*, ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] - * ! +- 'Sort ['age ASC NULLS FIRST], true - * ! 
+- 'UnresolvedRelation [relation], [], false + * Expected logical plan: 'Project [*] !+- 'Project [*, ((( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + ('nth_value('age, 3) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 3)) / 6) AS TEST_CUSTOM_COLUMN#0] ! +- 'Sort ['age ASC NULLS FIRST], true + * ! +- 'UnresolvedRelation [relation], [], false */ - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } @@ -193,50 +206,67 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite test("WMA - multiple trendline commands") { val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), context) + planTransformer.visit( + plan( + pplParser, + "source=relation | trendline sort age wma(2, age) as two_points_wma wma(3, age) as three_points_wma"), + context) - val dividendTwo = Add(getNthValueAggregation("age", "age", 1, -1), + val dividendTwo = Add( + getNthValueAggregation("age", "age", 1, -1), getNthValueAggregation("age", "age", 2, -1)) val twoPointsExpression = Divide(dividendTwo, Literal(3)) - val dividend = Add(Add(getNthValueAggregation("age", "age", 1, -2), - getNthValueAggregation("age", "age", 2, -2)), + val dividend = Add( + Add( + getNthValueAggregation("age", "age", 1, -2), + getNthValueAggregation("age", "age", 2, -2)), getNthValueAggregation("age", "age", 3, -2)) val threePointsExpression = Divide(dividend, Literal(6)) - val trendlineProjectList = 
Seq(UnresolvedStar(None), Alias(twoPointsExpression, "two_points_wma")(), Alias(threePointsExpression, "three_points_wma")()) + val trendlineProjectList = Seq( + UnresolvedStar(None), + Alias(twoPointsExpression, "two_points_wma")(), + Alias(threePointsExpression, "three_points_wma")()) val sortedTable = Sort( - Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, + Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), + global = true, UnresolvedRelation(Seq("relation"))) + /** - * Expected logical plan: - * 'Project [*] - * +- 'Project [*, (( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, + * Expected logical plan: 'Project [*] +- 'Project [*, (( ('nth_value('age, 1) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, + * currentrow$())) * 1) + ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -1, currentrow$())) * 2)) / 3) AS two_points_wma#0, * - * ((( - * ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + - * ('nth_value('age, 2) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 2)) + - * ('nth_value('age, 3) windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] - * +- 'Sort ['age ASC NULLS FIRST], true - * +- 'UnresolvedRelation [relation], [], false + * ((( ('nth_value('age, 1) windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 1) + ('nth_value('age, 2) + * windowspecdefinition('age ASC NULLS FIRST, specifiedwindowframe(RowFrame, -2, + * currentrow$())) * 2)) + ('nth_value('age, 3) 
windowspecdefinition('age ASC NULLS FIRST, + * specifiedwindowframe(RowFrame, -2, currentrow$())) * 3)) / 6) AS three_points_wma#1] +- + * 'Sort ['age ASC NULLS FIRST], true +- 'UnresolvedRelation [relation], [], false */ - val expectedPlan = Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) + val expectedPlan = + Project(Seq(UnresolvedStar(None)), Project(trendlineProjectList, sortedTable)) comparePlans(logPlan, expectedPlan, checkAnalysis = false) } - private def getNthValueAggregation(dataField: String, sortField: String, lookBackPos: Int, lookBackRange: Int): Expression = { + private def getNthValueAggregation( + dataField: String, + sortField: String, + lookBackPos: Int, + lookBackRange: Int): Expression = { Multiply( WindowExpression( - UnresolvedFunction("nth_value", Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), isDistinct = false), + UnresolvedFunction( + "nth_value", + Seq(UnresolvedAttribute(dataField), Literal(lookBackPos)), + isDistinct = false), WindowSpecDefinition( Seq(), seq(SortUtils.sortOrder(UnresolvedAttribute(sortField), true)), - SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow) - )), + SpecifiedWindowFrame(RowFrame, Literal(lookBackRange), CurrentRow))), Literal(lookBackPos)) } - } From 6b6661548808b1b81db1c23e9a09c18fce685960 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Tue, 12 Nov 2024 16:33:28 -0800 Subject: [PATCH 10/12] Update grammar rule Signed-off-by: Andy Kwok --- ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 7290a6d10..357673e73 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -267,7 +267,7 @@ trendlineCommand ; trendlineClause - : trendlineType LT_PRTHS numberOfDataPoints 
= integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + : trendlineType LT_PRTHS numberOfDataPoints = INTEGER_LITERAL COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? ; trendlineType From b6d356d09655d083a09f50df38dfea5137086122 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 13 Nov 2024 14:03:33 -0800 Subject: [PATCH 11/12] Address review comments Signed-off-by: Andy Kwok --- DEVELOPER_GUIDE.md | 4 ++-- .../function/BuiltinFunctionName.java | 2 -- .../sql/ppl/utils/TrendlineCatalystUtils.java | 19 ++++++++++++------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 23373fb84..834a2a201 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -13,7 +13,7 @@ sbt test ``` To run a specific unit test in SBT, use the testOnly command with the full path of the test class: ``` -sbt test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite +sbt "; project pplSparkIntegration; test:testOnly org.opensearch.flint.spark.ppl.PPLLogicalPlanTrendlineCommandTranslatorTestSuite" ``` @@ -30,7 +30,7 @@ If you get integration test failures with error message "Previous attempts to fi Running only a selected set of integration test suites is possible with the following command: ``` -sbt "project integtest" it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite +sbt "; project integtest; it:testOnly org.opensearch.flint.spark.ppl.FlintSparkPPLTrendlineITSuite" ``` This command runs only the specified test suite within the integtest submodule. 
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index e232c3668..86970cefb 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -292,8 +292,6 @@ public enum BuiltinFunctionName { MULTIMATCHQUERY(FunctionName.of("multimatchquery")), WILDCARDQUERY(FunctionName.of("wildcardquery")), WILDCARD_QUERY(FunctionName.of("wildcard_query")), - - NTH_VALUE(FunctionName.of("nth_value")), COALESCE(FunctionName.of("coalesce")); private FunctionName name; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java index 513561bfa..647f4542e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/TrendlineCatalystUtils.java @@ -5,6 +5,7 @@ package org.opensearch.sql.ppl.utils; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.*; import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.expression.Literal; @@ -12,6 +13,7 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.ppl.CatalystExpressionVisitor; import org.opensearch.sql.ppl.CatalystPlanContext; +import scala.collection.mutable.Seq; import scala.Option; import scala.Tuple2; @@ -22,6 +24,8 @@ import java.util.stream.Collectors; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static scala.Option.empty; +import static scala.collection.JavaConverters.asScalaBufferConverter; public interface 
TrendlineCatalystUtils { @@ -208,13 +212,14 @@ private static List getNthValueAggregations(CatalystExpressionVisito for (int i = 1; i <= dataPoints; i++) { // Get the offset parameter Expression offSetExpression = parseIntToExpression(visitor, context, i); - - // Composite the nth_value expression. - Function func = new Function(BuiltinFunctionName.NTH_VALUE.name(), - List.of(node.getDataField(), new Literal(i, DataType.INTEGER))); - - visitor.visitFunction(func, context); - Expression nthValueExp = context.popNamedParseExpressions().get(); + // Get the dataField in Expression + visitor.analyze(node.getDataField(), context); + Expression dataField = context.popNamedParseExpressions().get(); + // nth_value Expression + UnresolvedFunction nthValueExp = new UnresolvedFunction( + asScalaBufferConverter(List.of("nth_value")).asScala().seq(), + asScalaBufferConverter(List.of(dataField, offSetExpression)).asScala().seq(), + false, empty(), false); expressions.add(new Multiply( new WindowExpression(nthValueExp, windowDefinition), offSetExpression)); From 2b61f4cdadcdb77ca9af7ce5222395574caad6a3 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 13 Nov 2024 16:21:42 -0800 Subject: [PATCH 12/12] Address review comments Signed-off-by: Andy Kwok --- .../flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala | 9 +++++++++ ...LLogicalPlanTrendlineCommandTranslatorTestSuite.scala | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala index 589cad33b..9a8379288 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTrendlineITSuite.scala @@ -7,10 +7,12 @@ package org.opensearch.flint.spark.ppl import 
org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.opensearch.sql.ppl.utils.SortUtils +import org.scalatest.matchers.should.Matchers.{a, convertToAnyShouldWrapper} import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Ascending, CaseWhen, CurrentRow, Descending, Divide, Expression, LessThan, Literal, Multiply, RowFrame, SortOrder, SpecifiedWindowFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.streaming.StreamTest @@ -483,6 +485,13 @@ class FlintSparkPPLTrendlineITSuite } + test("test invalid wma command with negative dataPoint value") { + val exception = intercept[ParseException](sql(s""" + | source = $testTable | trendline sort + age wma(-3, age) + | """.stripMargin)) + assert(exception.getMessage contains "[PARSE_SYNTAX_ERROR] Syntax error") + } + private def getNthValueAggregation( dataField: String, sortField: String, diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala index baf472a08..ec1775631 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTrendlineCommandTranslatorTestSuite.scala @@ -6,6 +6,7 @@ package org.opensearch.flint.spark.ppl import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.common.antlr.SyntaxCheckException import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} import 
org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.opensearch.sql.ppl.utils.SortUtils @@ -251,6 +252,14 @@ class PPLLogicalPlanTrendlineCommandTranslatorTestSuite } + test("WMA - with negative dataPoint value") { + val context = new CatalystPlanContext + val exception = intercept[SyntaxCheckException]( + planTransformer + .visit(plan(pplParser, "source=relation | trendline sort age wma(-3, age)"), context)) + assert(exception.getMessage startsWith "Failed to parse query due to offending symbol [-]") + } + private def getNthValueAggregation( dataField: String, sortField: String,