diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 2155fb2efebc1..fe825c34f41cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -506,6 +506,7 @@ object FunctionRegistry {
     expression[RegrSXY]("regr_sxy"),
     expression[RegrSYY]("regr_syy"),
     expression[RegrSlope]("regr_slope"),
+    expression[RegrIntercept]("regr_intercept"),
 
     // string functions
     expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index 3bf6769f3c0b6..7a856a05b6fa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -35,8 +35,8 @@ abstract class Covariance(val left: Expression, val right: Expression, nullOnDiv
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
 
   protected[sql] val n = AttributeReference("n", DoubleType, nullable = false)()
-  protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
-  protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
+  protected[sql] val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
+  protected[sql] val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
   protected[sql] val ck = AttributeReference("ck", DoubleType, nullable = false)()
 
   protected def divideByZeroEvalResult: Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala
index d6d1b8fc6325f..c371f0b40c227 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala
@@ -269,8 +269,7 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg
   override lazy val aggBufferAttributes: Seq[AttributeReference] =
     covarPop.aggBufferAttributes ++ varPop.aggBufferAttributes
 
-  override lazy val initialValues: Seq[Expression] =
-    covarPop.initialValues ++ varPop.initialValues
+  override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues
 
   override lazy val updateExpressions: Seq[Expression] =
     covarPop.updateExpressions ++ varPop.updateExpressions
@@ -291,3 +290,75 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg
       newLeft: Expression, newRight: Expression): RegrSlope =
     copy(left = newLeft, right = newRight)
 }
+
+/**
+ * Intercept of the univariate linear regression line fitted to the non-null (y, x) pairs of a
+ * group. `left` is the dependent variable y and `right` is the independent variable x.
+ *
+ * The aggregation buffer is the buffer of `CovPopulation(x, y)` followed by the buffer of
+ * `VariancePop(x)`. The result is `yAvg - ck / m2 * xAvg`, or NULL when no complete pair was
+ * aggregated or when `m2` is zero.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the intercept of the univariate linear regression line for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x);
+       0.0
+      > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x);
+       NULL
+      > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x);
+       NULL
+  """,
+  group = "agg_funcs",
+  since = "3.4.0")
+// scalastyle:on line.size.limit
+case class RegrIntercept(left: Expression, right: Expression) extends DeclarativeAggregate
+  with ImplicitCastInputTypes with BinaryLike[Expression] {
+
+  private val covarPop = new CovPopulation(right, left)
+
+  private val varPop = new VariancePop(right)
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = DoubleType
+
+  override def inputTypes: Seq[DoubleType] = Seq(DoubleType, DoubleType)
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    covarPop.aggBufferAttributes ++ varPop.aggBufferAttributes
+
+  override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues
+
+  override lazy val updateExpressions: Seq[Expression] = {
+    // varPop is built from `right` (x) alone, so it cannot tell when the paired y is NULL. Keep
+    // every buffer slot unchanged for an incomplete pair so that m2 is accumulated over exactly
+    // the rows that feed ck, xAvg and yAvg.
+    val incompletePair = left.isNull || right.isNull
+    (covarPop.updateExpressions ++ varPop.updateExpressions).zip(aggBufferAttributes).map {
+      case (updated, current) => If(incompletePair, current, updated)
+    }
+  }
+
+  override lazy val mergeExpressions: Seq[Expression] =
+    covarPop.mergeExpressions ++ varPop.mergeExpressions
+
+  override lazy val evaluateExpression: Expression = {
+    // n == 0: no complete (y, x) pair was aggregated. m2 == 0: every aggregated x was identical,
+    // so the slope ck / m2 is undefined. Return NULL explicitly in both cases instead of relying
+    // on how Divide treats a zero divisor.
+    If(covarPop.n === 0.0 || varPop.m2 === 0.0, Literal.create(null, DoubleType),
+      covarPop.yAvg - covarPop.ck / varPop.m2 * covarPop.xAvg)
+  }
+
+  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+    covarPop.inputAggBufferAttributes ++ varPop.inputAggBufferAttributes
+
+  override def prettyName: String = "regr_intercept"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): RegrIntercept =
+    copy(left = newLeft, right = newRight)
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
index 5f19a4b79cd5c..9be586a31fc42 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
@@ -64,4 +64,21 @@ class AggregateExpressionSuite extends SparkFunSuite {
     assert(RegrSlope(Literal(3.0D), Literal(1D)).checkInputDataTypes() ===
       TypeCheckResult.TypeCheckSuccess)
   }
+
+  test("test regr_intercept input types") {
+    val checkResult1 = RegrIntercept(Literal("a"), Literal(1)).checkInputDataTypes()
+    assert(checkResult1.isInstanceOf[TypeCheckResult.TypeCheckFailure])
+    assert(checkResult1.asInstanceOf[TypeCheckResult.TypeCheckFailure].message
+      .contains("argument 1 requires double type, however, ''a'' is of string type"))
+    val checkResult2 = RegrIntercept(Literal(3.0D), Literal('b')).checkInputDataTypes()
+    assert(checkResult2.isInstanceOf[TypeCheckResult.TypeCheckFailure])
+    assert(checkResult2.asInstanceOf[TypeCheckResult.TypeCheckFailure].message
+      .contains("argument 2 requires double type, however, ''b'' is of string type"))
+    val checkResult3 = RegrIntercept(Literal(3.0D), Literal(Array(0))).checkInputDataTypes()
+    assert(checkResult3.isInstanceOf[TypeCheckResult.TypeCheckFailure])
+    assert(checkResult3.asInstanceOf[TypeCheckResult.TypeCheckFailure].message
+      .contains("argument 2 requires double type, however, '[0]' is of array type"))
+    assert(RegrIntercept(Literal(3.0D), Literal(1D)).checkInputDataTypes() ===
+      TypeCheckResult.TypeCheckSuccess)
+  }
 }
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 2d4eef394053e..2840bfd98f750 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -374,6 +374,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrCount | regr_count | SELECT regr_count(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct |
+| org.apache.spark.sql.catalyst.expressions.aggregate.RegrIntercept | regr_intercept | SELECT regr_intercept(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x) | struct |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrR2 | regr_r2 | SELECT regr_r2(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrSXX | regr_sxx | SELECT regr_sxx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct |
 | org.apache.spark.sql.catalyst.expressions.aggregate.RegrSXY | regr_sxy | SELECT regr_sxy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql
index 887e837e80f21..c7cb5bf1117a7 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql
@@ -44,3 +44,9 @@ SELECT regr_slope(y, x) FROM testRegression;
 SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL;
 SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k;
 SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k;
+
+-- SPARK-37623: Support ANSI Aggregate Function: regr_intercept
+SELECT regr_intercept(y, x) FROM testRegression;
+SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL;
+SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k;
+SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
index 5f3d6b7e60c37..1152d77da0cf4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
@@ -86,7 +86,7 @@ SELECT regr_syy(b, a) FROM aggtest;
 SELECT regr_sxy(b, a) FROM aggtest;
 SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
 SELECT regr_r2(b, a) FROM aggtest;
--- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
+SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
 SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
 SELECT corr(b, a) FROM aggtest;
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql
index c760da09ded19..4b816fb682b55 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udf/postgreSQL/udf-aggregates_part1.sql
@@ -86,7 +86,7 @@ SELECT regr_syy(b, a) FROM aggtest;
 SELECT regr_sxy(b, a) FROM aggtest;
 SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
 SELECT regr_r2(b, a) FROM aggtest;
--- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
+SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
 SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest;
 SELECT corr(b, udf(a)) FROM aggtest;
 
diff --git a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out
index 04378f32838df..0066ff8f064dd 100644
--- a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 29
+-- Number of queries: 33
 
 
 -- !query
@@ -244,3 +244,36 @@ SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT
 struct
 -- !query output
 2	0.8314087759815244
+
+
+-- !query
+SELECT regr_intercept(y, x) FROM testRegression
+-- !query schema
+struct
+-- !query output
+1.1547344110854496
+
+
+-- !query
+SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL
+-- !query schema
+struct
+-- !query output
+1.1547344110854496
+
+
+-- !query
+SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k
+-- !query schema
+struct
+-- !query output
+1	NULL
+2	1.1547344110854496
+
+
+-- !query
+SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k
+-- !query schema
+struct
+-- !query output
+2	1.1547344110854496
diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out
index 6d5b89f0c637b..ee44cd9171a1f 100644
--- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 56
+-- Number of queries: 57
 
 
 -- !query
@@ -336,6 +336,14 @@ struct
 0.019497798203180258
 
 
+-- !query
+SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest
+-- !query schema
+struct
+-- !query output
+0.5127507004412711	82.56199260123087
+
+
 -- !query
 SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest
 -- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out
index 268a8553cb487..49d3fbd7d877e 100644
--- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 55
+-- Number of queries: 56
 
 
 -- !query
@@ -327,6 +327,14 @@ struct
 0.019497798203180258
 
 
+-- !query
+SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest
+-- !query schema
+struct
+-- !query output
+0.5127507004412711	82.56199260123087
+
+
 -- !query
 SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest
 -- !query schema