From a6dfd385e79603415b1c067f150e0575b558c739 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 28 May 2022 11:31:07 +0800 Subject: [PATCH 1/4] [SPARK-37623][SQL] Support ANSI Aggregate Function: regr_intercept --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../aggregate/linearRegression.scala | 63 ++++++++++++++++++- .../aggregate/AggregateExpressionSuite.scala | 17 +++++ .../sql-functions/sql-expression-schema.md | 1 + .../sql-tests/inputs/linear-regression.sql | 6 ++ .../inputs/postgreSQL/aggregates_part1.sql | 2 +- .../udf/postgreSQL/udf-aggregates_part1.sql | 2 +- .../results/linear-regression.sql.out | 35 ++++++++++- .../postgreSQL/aggregates_part1.sql.out | 10 ++- .../postgreSQL/udf-aggregates_part1.sql.out | 10 ++- 10 files changed, 140 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 2155fb2efebc1..fe825c34f41cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -506,6 +506,7 @@ object FunctionRegistry { expression[RegrSXY]("regr_sxy"), expression[RegrSYY]("regr_syy"), expression[RegrSlope]("regr_slope"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index d6d1b8fc6325f..1bff7ff31c524 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -256,9 +256,9 @@ case class RegrSYY( case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - private val covarPop = new CovPopulation(right, left) + private[sql] val covarPop = new CovPopulation(right, left) - private val varPop = new VariancePop(right) + private[sql] val varPop = new VariancePop(right) override def nullable: Boolean = true @@ -291,3 +291,62 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg newLeft: Expression, newRight: Expression): RegrSlope = copy(left = newLeft, right = newRight) } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the intercept of the univariate linear regression line for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.", + examples = """ + Examples: + > SELECT _FUNC_(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x); + 0.0 + > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x); + NULL + > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x); + NULL + """, + group = "agg_funcs", + since = "3.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(left: Expression, right: Expression) extends DeclarativeAggregate + with ImplicitCastInputTypes with BinaryLike[Expression] { + + private val avgY = Average(left) + + private val regrSlope = RegrSlope(left, right) + + private val avgX = Average(right) + + override def nullable: Boolean = true + + override def dataType: DataType = DoubleType + + override def inputTypes: Seq[DoubleType] = Seq(DoubleType, DoubleType) + + override lazy val aggBufferAttributes: Seq[AttributeReference] = + avgY.aggBufferAttributes ++ regrSlope.aggBufferAttributes ++ avgX.aggBufferAttributes + + override lazy val initialValues: Seq[Expression] = + avgY.initialValues ++ regrSlope.initialValues ++ avgX.initialValues + + override lazy val updateExpressions: Seq[Expression] = + avgY.updateExpressions ++ regrSlope.updateExpressions ++ avgX.updateExpressions + + override lazy val mergeExpressions: Seq[Expression] = + avgY.mergeExpressions ++ regrSlope.mergeExpressions ++ avgX.mergeExpressions + + override lazy val evaluateExpression: Expression = { + If(regrSlope.covarPop.n === 0.0, Literal.create(null, DoubleType), + avgY.evaluateExpression - + (regrSlope.covarPop.ck / regrSlope.varPop.m2) * avgX.evaluateExpression) + } + + override lazy val inputAggBufferAttributes: Seq[AttributeReference] = + avgY.inputAggBufferAttributes ++ regrSlope.inputAggBufferAttributes ++ + avgX.inputAggBufferAttributes + + override def prettyName: String = "regr_intercept" + + override protected def withNewChildrenInternal( + newLeft: Expression, newRight: Expression): RegrIntercept = + copy(left = newLeft, right = newRight) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala index 5f19a4b79cd5c..9be586a31fc42 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala @@ -64,4 +64,21 @@ class AggregateExpressionSuite extends SparkFunSuite { assert(RegrSlope(Literal(3.0D), Literal(1D)).checkInputDataTypes() === TypeCheckResult.TypeCheckSuccess) } + + test("test regr_intercept input types") { + val checkResult1 = RegrIntercept(Literal("a"), Literal(1)).checkInputDataTypes() + assert(checkResult1.isInstanceOf[TypeCheckResult.TypeCheckFailure]) + assert(checkResult1.asInstanceOf[TypeCheckResult.TypeCheckFailure].message + .contains("argument 1 requires double type, however, ''a'' is of string type")) + val checkResult2 = RegrIntercept(Literal(3.0D), Literal('b')).checkInputDataTypes() + assert(checkResult2.isInstanceOf[TypeCheckResult.TypeCheckFailure]) + assert(checkResult2.asInstanceOf[TypeCheckResult.TypeCheckFailure].message + .contains("argument 2 requires double type, however, ''b'' is of string type")) + val checkResult3 = RegrIntercept(Literal(3.0D), Literal(Array(0))).checkInputDataTypes() + assert(checkResult3.isInstanceOf[TypeCheckResult.TypeCheckFailure]) + assert(checkResult3.asInstanceOf[TypeCheckResult.TypeCheckFailure].message + .contains("argument 2 requires double type, however, '[0]' is of array type")) + assert(RegrIntercept(Literal(3.0D), Literal(1D)).checkInputDataTypes() === + TypeCheckResult.TypeCheckSuccess) + } } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 2d4eef394053e..2840bfd98f750 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -374,6 +374,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrCount | regr_count | SELECT regr_count(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | +| org.apache.spark.sql.catalyst.expressions.aggregate.RegrIntercept | regr_intercept | SELECT regr_intercept(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrR2 | regr_r2 | SELECT regr_r2(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrSXX | regr_sxx | SELECT regr_sxx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.RegrSXY | regr_sxy | SELECT regr_sxy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct | diff --git a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql index 887e837e80f21..c7cb5bf1117a7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql @@ -44,3 +44,9 @@ SELECT regr_slope(y, x) FROM testRegression; SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL; SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k; SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k; + +-- SPARK-37623: Support ANSI Aggregate Function: regr_intercept +SELECT regr_intercept(y, x) FROM testRegression; +SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL; +SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k; +SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql index 5f3d6b7e60c37..1152d77da0cf4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql @@ -86,7 +86,7 @@ SELECT regr_syy(b, a) FROM aggtest; SELECT regr_sxy(b, a) FROM aggtest; SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; SELECT regr_r2(b, a) FROM aggtest; --- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; +SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest; SELECT corr(b, a) FROM aggtest; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql index c760da09ded19..4b816fb682b55 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql @@ -86,7 +86,7 @@ SELECT regr_syy(b, a) FROM aggtest; SELECT regr_sxy(b, a) FROM aggtest; SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; SELECT regr_r2(b, a) FROM aggtest; --- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; +SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest; SELECT corr(b, udf(a)) FROM aggtest; diff --git a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out index 04378f32838df..f70b369f63338 100644 --- a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 33 -- !query @@ -244,3 +244,36 @@ SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT struct -- !query output 2 0.8314087759815244 + + +-- !query +SELECT regr_intercept(y, x) FROM testRegression +-- !query schema +struct +-- !query output +0.15473441108544606 + + +-- !query +SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL +-- !query schema +struct +-- !query output +1.154734411085446 + + +-- !query +SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k +-- !query schema +struct +-- !query output +1 NULL +2 2.404734411085446 + + +-- !query +SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k +-- !query schema +struct +-- !query output +2 1.154734411085446 diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out index 6d5b89f0c637b..ee44cd9171a1f 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 56 +-- Number of queries: 57 -- !query @@ -336,6 +336,14 @@ struct 0.019497798203180258 +-- !query +SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest +-- !query schema +struct +-- !query output +0.5127507004412711 82.56199260123087 + + -- !query SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out index 268a8553cb487..49d3fbd7d877e 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 56 -- !query @@ -327,6 +327,14 @@ struct 0.019497798203180258 +-- !query +SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest +-- !query schema +struct +-- !query output +0.5127507004412711 82.56199260123087 + + -- !query SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest -- !query schema From 4d65722999ead74614becf1a61994224c827496c Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sun, 29 May 2022 10:18:37 +0800 Subject: [PATCH 2/4] Update code --- .../expressions/aggregate/linearRegression.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index 1bff7ff31c524..a978aaf9f2b87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -296,14 +296,14 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg @ExpressionDescription( usage = "_FUNC_(y, x) - Returns the intercept of the univariate linear regression line for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.", examples = """ - Examples: - > SELECT _FUNC_(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x); - 0.0 - > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x); - NULL - > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x); - NULL - """, + Examples: + > SELECT _FUNC_(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x); + 0.0 + > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x); + NULL + > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x); + NULL + """, group = "agg_funcs", since = "3.4.0") // scalastyle:on line.size.limit From d6507eeda024927eff41ca8caac1a677bb90d668 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 30 May 2022 13:47:44 +0800 Subject: [PATCH 3/4] Update code --- .../expressions/aggregate/Covariance.scala | 4 ++-- .../aggregate/linearRegression.scala | 22 +++++-------------- .../results/linear-regression.sql.out | 8 +++---- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index 3bf6769f3c0b6..7a856a05b6fa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -35,8 +35,8 @@ abstract class Covariance(val left: Expression, val right: Expression, nullOnDiv override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) protected[sql] val n = AttributeReference("n", DoubleType, nullable = false)() - protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() - protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() + protected[sql] val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)() + protected[sql] val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() protected[sql] val ck = AttributeReference("ck", DoubleType, nullable = false)() protected def divideByZeroEvalResult: Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index a978aaf9f2b87..eb44a632e90c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -310,39 +310,29 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg case class RegrIntercept(left: Expression, right: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - private val avgY = Average(left) - private val regrSlope = RegrSlope(left, right) - private val avgX = Average(right) - override def nullable: Boolean = true override def dataType: DataType = DoubleType override def inputTypes: Seq[DoubleType] = Seq(DoubleType, DoubleType) - override lazy val aggBufferAttributes: Seq[AttributeReference] = - avgY.aggBufferAttributes ++ regrSlope.aggBufferAttributes ++ avgX.aggBufferAttributes + override lazy val aggBufferAttributes: Seq[AttributeReference] = regrSlope.aggBufferAttributes - override lazy val initialValues: Seq[Expression] = - avgY.initialValues ++ regrSlope.initialValues ++ avgX.initialValues + override lazy val initialValues: Seq[Expression] = regrSlope.initialValues - override lazy val updateExpressions: Seq[Expression] = - avgY.updateExpressions ++ regrSlope.updateExpressions ++ avgX.updateExpressions + override lazy val updateExpressions: Seq[Expression] = regrSlope.updateExpressions - override lazy val mergeExpressions: Seq[Expression] = - avgY.mergeExpressions ++ regrSlope.mergeExpressions ++ avgX.mergeExpressions + override lazy val mergeExpressions: Seq[Expression] = regrSlope.mergeExpressions override lazy val evaluateExpression: Expression = { If(regrSlope.covarPop.n === 0.0, Literal.create(null, DoubleType), - avgY.evaluateExpression - - (regrSlope.covarPop.ck / regrSlope.varPop.m2) * avgX.evaluateExpression) + regrSlope.covarPop.yAvg - regrSlope.evaluateExpression * regrSlope.covarPop.xAvg) } override lazy val inputAggBufferAttributes: Seq[AttributeReference] = - avgY.inputAggBufferAttributes ++ regrSlope.inputAggBufferAttributes ++ - avgX.inputAggBufferAttributes + regrSlope.inputAggBufferAttributes override def prettyName: String = "regr_intercept" diff --git a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out index f70b369f63338..0066ff8f064dd 100644 --- a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out @@ -251,7 +251,7 @@ SELECT regr_intercept(y, x) FROM testRegression -- !query schema struct -- !query output -0.15473441108544606 +1.1547344110854496 -- !query @@ -259,7 +259,7 @@ SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT -- !query schema struct -- !query output -1.154734411085446 +1.1547344110854496 -- !query @@ -268,7 +268,7 @@ SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k struct -- !query output 1 NULL -2 2.404734411085446 +2 1.1547344110854496 -- !query @@ -276,4 +276,4 @@ SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS -- !query schema struct -- !query output -2 1.154734411085446 +2 1.1547344110854496 From 537d420a1a2e7d81c3a3152c6b2b34dbb5f57fa5 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 31 May 2022 11:07:31 +0800 Subject: [PATCH 4/4] Update code --- .../aggregate/linearRegression.scala | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index eb44a632e90c4..c371f0b40c227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -256,9 +256,9 @@ case class RegrSYY( case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - private[sql] val covarPop = new CovPopulation(right, left) + private val covarPop = new CovPopulation(right, left) - private[sql] val varPop = new VariancePop(right) + private val varPop = new VariancePop(right) override def nullable: Boolean = true @@ -269,8 +269,7 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg override lazy val aggBufferAttributes: Seq[AttributeReference] = covarPop.aggBufferAttributes ++ varPop.aggBufferAttributes - override lazy val initialValues: Seq[Expression] = - covarPop.initialValues ++ varPop.initialValues + override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues override lazy val updateExpressions: Seq[Expression] = covarPop.updateExpressions ++ varPop.updateExpressions @@ -310,7 +309,9 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg case class RegrIntercept(left: Expression, right: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { - private val regrSlope = RegrSlope(left, right) + private val covarPop = new CovPopulation(right, left) + + private val varPop = new VariancePop(right) override def nullable: Boolean = true @@ -318,21 +319,24 @@ case class RegrIntercept(left: Expression, right: Expression) extends Declarativ override def inputTypes: Seq[DoubleType] = Seq(DoubleType, DoubleType) - override lazy val aggBufferAttributes: Seq[AttributeReference] = regrSlope.aggBufferAttributes + override lazy val aggBufferAttributes: Seq[AttributeReference] = + covarPop.aggBufferAttributes ++ varPop.aggBufferAttributes - override lazy val initialValues: Seq[Expression] = regrSlope.initialValues + override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues - override lazy val updateExpressions: Seq[Expression] = regrSlope.updateExpressions + override lazy val updateExpressions: Seq[Expression] = + covarPop.updateExpressions ++ varPop.updateExpressions - override lazy val mergeExpressions: Seq[Expression] = regrSlope.mergeExpressions + override lazy val mergeExpressions: Seq[Expression] = + covarPop.mergeExpressions ++ varPop.mergeExpressions override lazy val evaluateExpression: Expression = { - If(regrSlope.covarPop.n === 0.0, Literal.create(null, DoubleType), - regrSlope.covarPop.yAvg - regrSlope.evaluateExpression * regrSlope.covarPop.xAvg) + If(covarPop.n === 0.0, Literal.create(null, DoubleType), + covarPop.yAvg - covarPop.ck / varPop.m2 * covarPop.xAvg) } override lazy val inputAggBufferAttributes: Seq[AttributeReference] = - regrSlope.inputAggBufferAttributes + covarPop.inputAggBufferAttributes ++ varPop.inputAggBufferAttributes override def prettyName: String = "regr_intercept"