Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ object FunctionRegistry {
expression[RegrSXY]("regr_sxy"),
expression[RegrSYY]("regr_syy"),
expression[RegrSlope]("regr_slope"),
expression[RegrIntercept]("regr_intercept"),

// string functions
expression[Ascii]("ascii"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ abstract class Covariance(val left: Expression, val right: Expression, nullOnDiv
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

protected[sql] val n = AttributeReference("n", DoubleType, nullable = false)()
protected val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
protected[sql] val xAvg = AttributeReference("xAvg", DoubleType, nullable = false)()
protected[sql] val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)()
protected[sql] val ck = AttributeReference("ck", DoubleType, nullable = false)()

protected def divideByZeroEvalResult: Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,7 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg
override lazy val aggBufferAttributes: Seq[AttributeReference] =
covarPop.aggBufferAttributes ++ varPop.aggBufferAttributes

override lazy val initialValues: Seq[Expression] =
covarPop.initialValues ++ varPop.initialValues
override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues

override lazy val updateExpressions: Seq[Expression] =
covarPop.updateExpressions ++ varPop.updateExpressions
Expand All @@ -291,3 +290,57 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg
newLeft: Expression, newRight: Expression): RegrSlope =
copy(left = newLeft, right = newRight)
}

/**
 * Declarative aggregate computing the intercept of the least-squares regression line of
 * `left` (the dependent variable y) on `right` (the independent variable x).
 *
 * The aggregation buffer is the concatenation of the buffers of two helper aggregates:
 * a population covariance over `(right, left)` and a population variance over `right`.
 * The result is `avg(y) - slope * avg(x)` with `slope = ck / m2`, and is NULL when no
 * pair has been accumulated.
 *
 * @param left  the dependent variable (y)
 * @param right the independent variable (x)
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(y, x) - Returns the intercept of the univariate linear regression line for non-null pairs in a group, where `y` is the dependent variable and `x` is the independent variable.",
  examples = """
    Examples:
      > SELECT _FUNC_(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x);
       0.0
      > SELECT _FUNC_(y, x) FROM VALUES (1, null) AS tab(y, x);
       NULL
      > SELECT _FUNC_(y, x) FROM VALUES (null, 1) AS tab(y, x);
       NULL
  """,
  group = "agg_funcs",
  since = "3.4.0")
// scalastyle:on line.size.limit
case class RegrIntercept(left: Expression, right: Expression) extends DeclarativeAggregate
  with ImplicitCastInputTypes with BinaryLike[Expression] {

  // Note the swapped argument order: the covariance helper receives x first and y second.
  private val covarPop = new CovPopulation(right, left)

  // The variance helper only ever sees x; it has no knowledge of y.
  private val varPop = new VariancePop(right)

  override def nullable: Boolean = true

  override def dataType: DataType = DoubleType

  override def inputTypes: Seq[DoubleType] = Seq(DoubleType, DoubleType)

  override lazy val aggBufferAttributes: Seq[AttributeReference] =
    covarPop.aggBufferAttributes ++ varPop.aggBufferAttributes

  override lazy val initialValues: Seq[Expression] =
    covarPop.initialValues ++ varPop.initialValues

  override lazy val updateExpressions: Seq[Expression] = {
    // The function is defined over non-null (y, x) pairs only. Because `varPop` is built
    // from `right` alone, it cannot by itself skip a row whose y is NULL but whose x is
    // not; such a row would enlarge m2 without contributing to ck and skew the slope.
    // Keep the previous variance buffer value whenever either side of the pair is NULL.
    val pairHasNull = left.isNull || right.isNull
    val guardedVarPopUpdates = varPop.updateExpressions.zip(varPop.aggBufferAttributes).map {
      case (updated, current) => If(pairHasNull, current, updated)
    }
    covarPop.updateExpressions ++ guardedVarPopUpdates
  }

  override lazy val mergeExpressions: Seq[Expression] =
    covarPop.mergeExpressions ++ varPop.mergeExpressions

  override lazy val evaluateExpression: Expression = {
    // slope = ck / m2; intercept = yAvg - slope * xAvg.
    val slope = covarPop.ck / varPop.m2
    If(covarPop.n === 0.0, Literal.create(null, DoubleType),
      covarPop.yAvg - slope * covarPop.xAvg)
  }

  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
    covarPop.inputAggBufferAttributes ++ varPop.inputAggBufferAttributes

  override def prettyName: String = "regr_intercept"

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): RegrIntercept =
    copy(left = newLeft, right = newRight)
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,21 @@ class AggregateExpressionSuite extends SparkFunSuite {
assert(RegrSlope(Literal(3.0D), Literal(1D)).checkInputDataTypes() ===
TypeCheckResult.TypeCheckSuccess)
}

test("test regr_intercept input types") {
  // Verifies that a type check was rejected and that the failure message
  // contains the expected fragment.
  def expectFailure(outcome: TypeCheckResult, fragment: String): Unit = {
    assert(outcome.isInstanceOf[TypeCheckResult.TypeCheckFailure])
    val failure = outcome.asInstanceOf[TypeCheckResult.TypeCheckFailure]
    assert(failure.message.contains(fragment))
  }

  // A string in the first position is rejected.
  expectFailure(
    RegrIntercept(Literal("a"), Literal(1)).checkInputDataTypes(),
    "argument 1 requires double type, however, ''a'' is of string type")

  // A character literal in the second position is rejected as a string.
  expectFailure(
    RegrIntercept(Literal(3.0D), Literal('b')).checkInputDataTypes(),
    "argument 2 requires double type, however, ''b'' is of string type")

  // An array in the second position is rejected.
  expectFailure(
    RegrIntercept(Literal(3.0D), Literal(Array(0))).checkInputDataTypes(),
    "argument 2 requires double type, however, '[0]' is of array<int> type")

  // Two doubles pass the check.
  val accepted = RegrIntercept(Literal(3.0D), Literal(1D)).checkInputDataTypes()
  assert(accepted === TypeCheckResult.TypeCheckSuccess)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgX | regr_avgx | SELECT regr_avgx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgx(y, x):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrAvgY | regr_avgy | SELECT regr_avgy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_avgy(y, x):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrCount | regr_count | SELECT regr_count(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_count(y, x):bigint> |
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrIntercept | regr_intercept | SELECT regr_intercept(y, x) FROM VALUES (1,1), (2,2), (3,3) AS tab(y, x) | struct<regr_intercept(y, x):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrR2 | regr_r2 | SELECT regr_r2(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_r2(y, x):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrSXX | regr_sxx | SELECT regr_sxx(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_sxx(y, x):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.RegrSXY | regr_sxy | SELECT regr_sxy(y, x) FROM VALUES (1, 2), (2, 2), (2, 3), (2, 4) AS tab(y, x) | struct<regr_sxy(y, x):double> |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ SELECT regr_slope(y, x) FROM testRegression;
SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL;
SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k;
SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k;

-- SPARK-37623: Support ANSI Aggregate Function: regr_intercept
-- Ungrouped, over every row of testRegression.
SELECT regr_intercept(y, x) FROM testRegression;
-- Ungrouped, restricted to rows where both x and y are non-null.
SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL;
-- Grouped by k, over every row of testRegression.
SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k;
-- Grouped by k, restricted to rows where both x and y are non-null.
SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k;
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ SELECT regr_syy(b, a) FROM aggtest;
SELECT regr_sxy(b, a) FROM aggtest;
SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
SELECT regr_r2(b, a) FROM aggtest;
-- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
SELECT corr(b, a) FROM aggtest;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ SELECT regr_syy(b, a) FROM aggtest;
SELECT regr_sxy(b, a) FROM aggtest;
SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
SELECT regr_r2(b, a) FROM aggtest;
-- SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest;
SELECT corr(b, udf(a)) FROM aggtest;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 29
-- Number of queries: 33


-- !query
Expand Down Expand Up @@ -244,3 +244,36 @@ SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT
struct<k:int,regr_slope(y, x):double>
-- !query output
2 0.8314087759815244


-- !query
SELECT regr_intercept(y, x) FROM testRegression
-- !query schema
struct<regr_intercept(y, x):double>
-- !query output
1.1547344110854496


-- !query
SELECT regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL
-- !query schema
struct<regr_intercept(y, x):double>
-- !query output
1.1547344110854496


-- !query
SELECT k, regr_intercept(y, x) FROM testRegression GROUP BY k
-- !query schema
struct<k:int,regr_intercept(y, x):double>
-- !query output
1 NULL
2 1.1547344110854496


-- !query
SELECT k, regr_intercept(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k
-- !query schema
struct<k:int,regr_intercept(y, x):double>
-- !query output
2 1.1547344110854496
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 56
-- Number of queries: 57


-- !query
Expand Down Expand Up @@ -336,6 +336,14 @@ struct<regr_r2(b, a):double>
0.019497798203180258


-- !query
SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest
-- !query schema
struct<regr_slope(b, a):double,regr_intercept(b, a):double>
-- !query output
0.5127507004412711 82.56199260123087


-- !query
SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest
-- !query schema
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 55
-- Number of queries: 56


-- !query
Expand Down Expand Up @@ -327,6 +327,14 @@ struct<regr_r2(b, a):double>
0.019497798203180258


-- !query
SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest
-- !query schema
struct<regr_slope(b, a):double,regr_intercept(b, a):double>
-- !query output
0.5127507004412711 82.56199260123087


-- !query
SELECT udf(covar_pop(b, udf(a))), covar_samp(udf(b), a) FROM aggtest
-- !query schema
Expand Down