diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 268f5734813ba..077dfc6770d94 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2147,7 +2147,7 @@ test_that("group by, agg functions", { df3 <- agg(gd, age = "stddev") expect_is(df3, "SparkDataFrame") df3_local <- collect(df3) - expect_true(is.nan(df3_local[df3_local$name == "Andy", ][1, 2])) + expect_true(is.na(df3_local[df3_local$name == "Andy", ][1, 2])) df4 <- agg(gd, sumAge = sum(df$age)) expect_is(df4, "SparkDataFrame") @@ -2178,7 +2178,7 @@ test_that("group by, agg functions", { df7 <- agg(gd2, value = "stddev") df7_local <- collect(df7) expect_true(abs(df7_local[df7_local$name == "ID1", ][1, 2] - 6.928203) < 1e-6) - expect_true(is.nan(df7_local[df7_local$name == "ID2", ][1, 2])) + expect_true(is.na(df7_local[df7_local$name == "ID2", ][1, 2])) mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Andy\", \"age\":30}", diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index feff2c7e9f543..c1de58d85d5bf 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -24,6 +24,8 @@ license: | ## Upgrading from Spark SQL 3.0 to 3.1 + - In Spark 3.1, the statistical aggregation functions `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, and `corr` return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` is applied on a single-element set. In Spark version 3.0 and earlier, they return `Double.NaN` in such cases. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`. + - In Spark 3.1, grouping_id() returns long values. In Spark version 3.0 and earlier, this function returns int values. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.integerGroupingId` to `true`. - In Spark 3.1, SQL UI data adopts the `formatted` mode for the query plan explain results. To restore the behavior before Spark 3.1, you can set `spark.sql.ui.explainMode` to `extended`.
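As a quick illustration of the migration note above (a minimal spark-shell sketch, not part of the patch; the query and session setup are illustrative and assume a Spark 3.1 session with default settings):

// Spark 3.1 default: stddev_samp over a single-element set divides by (n - 1) = 0
// and now evaluates to NULL.
spark.sql("SELECT stddev_samp(col) FROM VALUES (1.0) AS t(col)").collect()
// => Array([null])

// Setting the legacy flag restores the pre-3.1 behavior and yields Double.NaN again.
spark.conf.set("spark.sql.legacy.statisticalAggregate", "true")
spark.sql("SELECT stddev_samp(col) FROM VALUES (1.0) AS t(col)").collect()
// => Array([NaN])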
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index deaa49bf423b1..f72d9be205df3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -450,14 +450,20 @@ object TypeCoercion { case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) - case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) + case s @ StddevPop(e @ StringType(), _) => + s.withNewChildren(Seq(Cast(e, DoubleType))) + case s @ StddevSamp(e @ StringType(), _) => + s.withNewChildren(Seq(Cast(e, DoubleType))) case UnaryMinus(e @ StringType()) => UnaryMinus(Cast(e, DoubleType)) case UnaryPositive(e @ StringType()) => UnaryPositive(Cast(e, DoubleType)) - case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) - case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) - case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) - case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) + case v @ VariancePop(e @ StringType(), _) => + v.withNewChildren(Seq(Cast(e, DoubleType))) + case v @ VarianceSamp(e @ StringType(), _) => + v.withNewChildren(Seq(Cast(e, DoubleType))) + case s @ Skewness(e @ StringType(), _) => + s.withNewChildren(Seq(Cast(e, DoubleType))) + case k @ Kurtosis(e @ StringType(), _) => + k.withNewChildren(Seq(Cast(e, DoubleType))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 53759ca3d9165..2cc9adb5aa06e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -43,7 +44,7 @@ import org.apache.spark.sql.types._ * * @param child to compute central moments of. 
*/ -abstract class CentralMomentAgg(child: Expression) +abstract class CentralMomentAgg(child: Expression, nullOnDivideByZero: Boolean) extends DeclarativeAggregate with ImplicitCastInputTypes { /** @@ -62,6 +63,13 @@ abstract class CentralMomentAgg(child: Expression) protected val m3 = AttributeReference("m3", DoubleType, nullable = false)() protected val m4 = AttributeReference("m4", DoubleType, nullable = false)() + protected def divideByZeroEvalResult: Expression = { + if (nullOnDivideByZero) Literal.create(null, DoubleType) else Double.NaN + } + + override def stringArgs: Iterator[Any] = + super.stringArgs.filter(_.isInstanceOf[Expression]) + private def trimHigherOrder[T](expressions: Seq[T]) = expressions.take(momentOrder + 1) override val aggBufferAttributes = trimHigherOrder(Seq(n, avg, m2, m3, m4)) @@ -145,7 +153,12 @@ abstract class CentralMomentAgg(child: Expression) group = "agg_funcs", since = "1.6.0") // scalastyle:on line.size.limit -case class StddevPop(child: Expression) extends CentralMomentAgg(child) { +case class StddevPop( + child: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends CentralMomentAgg(child, nullOnDivideByZero) { + + def this(child: Expression) = this(child, !SQLConf.get.legacyStatisticalAggregate) override protected def momentOrder = 2 @@ -168,13 +181,18 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { group = "agg_funcs", since = "1.6.0") // scalastyle:on line.size.limit -case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { +case class StddevSamp( + child: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends CentralMomentAgg(child, nullOnDivideByZero) { + + def this(child: Expression) = this(child, !SQLConf.get.legacyStatisticalAggregate) override protected def momentOrder = 2 override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0)))) + If(n === 1.0, divideByZeroEvalResult, sqrt(m2 / (n - 1.0)))) } override def prettyName: String = @@ -191,7 +209,12 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { """, group = "agg_funcs", since = "1.6.0") -case class VariancePop(child: Expression) extends CentralMomentAgg(child) { +case class VariancePop( + child: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends CentralMomentAgg(child, nullOnDivideByZero) { + + def this(child: Expression) = this(child, !SQLConf.get.legacyStatisticalAggregate) override protected def momentOrder = 2 @@ -212,13 +235,18 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { """, group = "agg_funcs", since = "1.6.0") -case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { +case class VarianceSamp( + child: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends CentralMomentAgg(child, nullOnDivideByZero) { + + def this(child: Expression) = this(child, !SQLConf.get.legacyStatisticalAggregate) override protected def momentOrder = 2 override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(n === 1.0, Double.NaN, m2 / (n - 1.0))) + If(n === 1.0, divideByZeroEvalResult, m2 / (n - 1.0))) } override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp") @@ -235,7 +263,12 @@ case class VarianceSamp(child: Expression) extends 
CentralMomentAgg(child) { """, group = "agg_funcs", since = "1.6.0") -case class Skewness(child: Expression) extends CentralMomentAgg(child) { +case class Skewness( + child: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends CentralMomentAgg(child, nullOnDivideByZero) { + + def this(child: Expression) = this(child, !SQLConf.get.legacyStatisticalAggregate) override def prettyName: String = "skewness" @@ -243,7 +276,7 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) + If(m2 === 0.0, divideByZeroEvalResult, sqrt(n) * m3 / sqrt(m2 * m2 * m2))) } } @@ -258,13 +291,18 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { """, group = "agg_funcs", since = "1.6.0") -case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { +case class Kurtosis( + child: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends CentralMomentAgg(child, nullOnDivideByZero) { + + def this(child: Expression) = this(child, !SQLConf.get.legacyStatisticalAggregate) override protected def momentOrder = 4 override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(m2 === 0.0, Double.NaN, n * m4 / (m2 * m2) - 3.0)) + If(m2 === 0.0, divideByZeroEvalResult, n * m4 / (m2 * m2) - 3.0)) } override def prettyName: String = "kurtosis" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 9ef05bb5d4fec..737e8cd3ffa41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -28,7 +29,7 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -abstract class PearsonCorrelation(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression, nullOnDivideByZero: Boolean) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -43,6 +44,13 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) protected val xMk = AttributeReference("xMk", DoubleType, nullable = false)() protected val yMk = AttributeReference("yMk", DoubleType, nullable = false)() + protected def divideByZeroEvalResult: Expression = { + if (nullOnDivideByZero) Literal.create(null, DoubleType) else Double.NaN + } + + override def stringArgs: Iterator[Any] = + super.stringArgs.filter(_.isInstanceOf[Expression]) + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck, xMk, yMk) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) @@ -102,12 +110,18 @@ abstract class PearsonCorrelation(x: Expression, y: Expression) group = "agg_funcs", since = "1.6.0") // scalastyle:on line.size.limit -case class Corr(x: Expression, y: 
Expression) - extends PearsonCorrelation(x, y) { +case class Corr( + x: Expression, + y: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends PearsonCorrelation(x, y, nullOnDivideByZero) { + + def this(x: Expression, y: Expression) = + this(x, y, !SQLConf.get.legacyStatisticalAggregate) override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), - If(n === 1.0, Double.NaN, ck / sqrt(xMk * yMk))) + If(n === 1.0, divideByZeroEvalResult, ck / sqrt(xMk * yMk))) } override def prettyName: String = "corr" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index f03c2f2710a04..7c4d6ded6559e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * Compute the covariance between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. */ -abstract class Covariance(x: Expression, y: Expression) +abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -38,6 +39,13 @@ abstract class Covariance(x: Expression, y: Expression) protected val yAvg = AttributeReference("yAvg", DoubleType, nullable = false)() protected val ck = AttributeReference("ck", DoubleType, nullable = false)() + protected def divideByZeroEvalResult: Expression = { + if (nullOnDivideByZero) Literal.create(null, DoubleType) else Double.NaN + } + + override def stringArgs: Iterator[Any] = + super.stringArgs.filter(_.isInstanceOf[Expression]) + override val aggBufferAttributes: Seq[AttributeReference] = Seq(n, xAvg, yAvg, ck) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) @@ -88,7 +96,15 @@ abstract class Covariance(x: Expression, y: Expression) """, group = "agg_funcs", since = "2.0.0") -case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { +case class CovPopulation( + left: Expression, + right: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends Covariance(left, right, nullOnDivideByZero) { + + def this(left: Expression, right: Expression) = + this(left, right, !SQLConf.get.legacyStatisticalAggregate) + override val evaluateExpression: Expression = { If(n === 0.0, Literal.create(null, DoubleType), ck / n) } @@ -105,10 +121,18 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance """, group = "agg_funcs", since = "2.0.0") -case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { +case class CovSample( + left: Expression, + right: Expression, + nullOnDivideByZero: Boolean = !SQLConf.get.legacyStatisticalAggregate) + extends Covariance(left, right, nullOnDivideByZero) { + + def this(left: Expression, right: Expression) = + this(left, right, !SQLConf.get.legacyStatisticalAggregate) + override val evaluateExpression: Expression = { If(n === 0.0, 
Literal.create(null, DoubleType), - If(n === 1.0, Double.NaN, ck / (n - 1.0))) + If(n === 1.0, divideByZeroEvalResult, ck / (n - 1.0))) } override def prettyName: String = "covar_samp" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8cbdbfe16d2bc..e25c4e52e4e9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2342,6 +2342,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_STATISTICAL_AGGREGATE = + buildConf("spark.sql.legacy.statisticalAggregate") + .internal() + .doc("When set to true, statistical aggregate functions return Double.NaN " + + "if divide by zero occurs during expression evaluation; otherwise, they return null. " + + "Before version 3.1.0, NaN was returned in the divide-by-zero case by default.") + .version("3.1.0") + .booleanConf + .createWithDefault(false) + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") .internal() @@ -3364,6 +3374,8 @@ class SQLConf extends Serializable with Logging { def allowNegativeScaleOfDecimalEnabled: Boolean = getConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED) + def legacyStatisticalAggregate: Boolean = getConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE) + def truncateTableIgnorePermissionAcl: Boolean = getConf(SQLConf.TRUNCATE_TABLE_IGNORE_PERMISSION_ACL) diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out index f7bba96738eab..212365f92946c 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/aggregates_part1.sql.out @@ -143,7 +143,7 @@ SELECT var_pop(1.0), var_samp(2.0) -- !query schema struct -- !query output -0.0 NaN +0.0 NULL -- !query @@ -151,7 +151,7 @@ SELECT stddev_pop(CAST(3.0 AS Decimal(38,0))), stddev_samp(CAST(4.0 AS Decimal(3 -- !query schema struct -- !query output -0.0 NaN +0.0 NULL -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out index 4dd4712345a89..f7439d873b4eb 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/window_part4.sql.out @@ -195,7 +195,7 @@ struct -- !query output -NaN +NULL -- !query @@ -2558,7 +2558,7 @@ SELECT var_samp('1') FROM t -- !query schema struct -- !query output -NaN +NULL -- !query @@ -2566,7 +2566,7 @@ SELECT skewness('1') FROM t -- !query schema struct -- !query output -NaN +NULL -- !query @@ -2574,4 +2574,4 @@ SELECT kurtosis('1') FROM t -- !query schema struct -- !query output -NaN +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out index 76637bf578e6f..a428a7a9c923b 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/postgreSQL/udf-aggregates_part1.sql.out @@ -143,7 +143,7 @@ SELECT udf(var_pop(1.0)), var_samp(udf(2.0)) -- !query schema struct -- !query output -0.0 NaN +0.0 
NULL -- !query @@ -151,7 +151,7 @@ SELECT stddev_pop(udf(CAST(3.0 AS Decimal(38,0)))), stddev_samp(CAST(udf(4.0) AS -- !query schema struct -- !query output -0.0 NaN +0.0 NULL -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-window.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-window.sql.out index a84070535b658..928b9ebb12364 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-window.sql.out @@ -289,13 +289,13 @@ ORDER BY cate, udf(val) struct,collect_set:array,skewness:double,kurtosis:double> -- !query output NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +3 NULL 3 3 3 1 3 3.0 NULL NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NULL 1 0.0 NULL NULL 0.0 [3] [3] NULL NULL +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL 1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +1 b 1 1 1 1 1 1.0 NULL 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NULL 1 NULL NULL NULL 0.0 [1] [1] NULL NULL +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NULL 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index a8875fd449bad..028dd7a12d25d 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 32 -- !query @@ -313,13 +313,13 @@ ORDER BY cate, val struct,collect_set:array,skewness:double,kurtosis:double> -- !query output NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +3 NULL 3 3 3 1 3 3.0 NULL NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NULL 1 0.0 NULL NULL 0.0 [3] [3] NULL NULL +NULL a NULL NULL NULL 0 NULL 
NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL 1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +1 b 1 1 1 1 1 1.0 NULL 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NULL 1 NULL NULL NULL 0.0 [1] [1] NULL NULL +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NULL 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 353444b664412..d4e64aa03df0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -456,25 +456,51 @@ class DataFrameAggregateSuite extends QueryTest } test("zero moments") { - val input = Seq((1, 2)).toDF("a", "b") - checkAnswer( - input.agg(stddev($"a"), stddev_samp($"a"), stddev_pop($"a"), variance($"a"), - var_samp($"a"), var_pop($"a"), skewness($"a"), kurtosis($"a")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, - Double.NaN, Double.NaN)) + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") { + val input = Seq((1, 2)).toDF("a", "b") + checkAnswer( + input.agg(stddev($"a"), stddev_samp($"a"), stddev_pop($"a"), variance($"a"), + var_samp($"a"), var_pop($"a"), skewness($"a"), kurtosis($"a")), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) - checkAnswer( - input.agg( - expr("stddev(a)"), - expr("stddev_samp(a)"), - expr("stddev_pop(a)"), - expr("variance(a)"), - expr("var_samp(a)"), - expr("var_pop(a)"), - expr("skewness(a)"), - expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, - Double.NaN, Double.NaN)) + checkAnswer( + input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) + } + } + + test("SPARK-13860: zero moments LEGACY_STATISTICAL_AGGREGATE off") { + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false") { + val input = Seq((1, 2)).toDF("a", "b") + checkAnswer( + input.agg(stddev($"a"), stddev_samp($"a"), stddev_pop($"a"), variance($"a"), + var_samp($"a"), var_pop($"a"), skewness($"a"), kurtosis($"a")), + Row(null, null, 0.0, null, null, 0.0, + null, null)) + + checkAnswer( + input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + 
expr("stddev_pop(a)"), + expr("variance(a)"), + expr("var_samp(a)"), + expr("var_pop(a)"), + expr("skewness(a)"), + expr("kurtosis(a)")), + Row(null, null, 0.0, null, null, 0.0, + null, null)) + } } test("null moments") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index c5dcdc44cc64f..616e333033aa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -94,89 +94,187 @@ class DataFrameWindowFunctionsSuite extends QueryTest } test("corr, covar_pop, stddev_pop functions in specific window") { - val df = Seq( - ("a", "p1", 10.0, 20.0), - ("b", "p1", 20.0, 10.0), - ("c", "p2", 20.0, 20.0), - ("d", "p2", 20.0, 20.0), - ("e", "p3", 0.0, 0.0), - ("f", "p3", 6.0, 12.0), - ("g", "p3", 6.0, 12.0), - ("h", "p3", 8.0, 16.0), - ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") - checkAnswer( - df.select( - $"key", - corr("value1", "value2").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - covar_pop("value1", "value2") - .over(Window.partitionBy("partitionId") + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. 
+ // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + } + + test("SPARK-13860: " + + "corr, covar_pop, stddev_pop functions in specific window " + + "LEGACY_STATISTICAL_AGGREGATE off") { + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. 
+ // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", null, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "true") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - var_pop("value1") - .over(Window.partitionBy("partitionId") + var_samp("value1").over(Window.partitionBy("partitionId") .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev_pop("value1") - .over(Window.partitionBy("partitionId") + variance("value1").over(Window.partitionBy("partitionId") .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - var_pop("value2") - .over(Window.partitionBy("partitionId") + stddev_samp("value1").over(Window.partitionBy("partitionId") .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev_pop("value2") - .over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), - - // As stddev_pop(expr) = sqrt(var_pop(expr)) - // the "stddev_pop" column can be calculated from the "var_pop" column. - // - // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) - // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. 
- Seq( - Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), - Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), - Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), - Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), - Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } } - test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { - val df = Seq( - ("a", "p1", 10.0, 20.0), - ("b", "p1", 20.0, 10.0), - ("c", "p2", 20.0, 20.0), - ("d", "p2", 20.0, 20.0), - ("e", "p3", 0.0, 0.0), - ("f", "p3", 6.0, 12.0), - ("g", "p3", 6.0, 12.0), - ("h", "p3", 8.0, 16.0), - ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") - checkAnswer( - df.select( - $"key", - covar_samp("value1", "value2").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - var_samp("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - variance("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev_samp("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), - stddev("value1").over(Window.partitionBy("partitionId") - .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - ), - Seq( - Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), - Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), - Row("c", 0.0, 0.0, 0.0, 0.0, 0.0), - Row("d", 0.0, 0.0, 0.0, 0.0, 0.0), - Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + test("SPARK-13860: " + + "covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window " + + "LEGACY_STATISTICAL_AGGREGATE off") { + withSQLConf(SQLConf.LEGACY_STATISTICAL_AGGREGATE.key -> "false") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + 
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), + Row("i", null, null, null, null, null))) + } } test("collect_list in ascending ordered window") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 87771eed17b1b..70dcfb05c2ba9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -825,7 +825,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), - Row(Double.NaN) :: Nil) + Row(null) :: Nil) checkAnswer( spark.sql( @@ -834,10 +834,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te """.stripMargin), Row(1, null) :: Row(2, null) :: - Row(3, Double.NaN) :: - Row(4, Double.NaN) :: - Row(5, Double.NaN) :: - Row(6, Double.NaN) :: Nil) + Row(3, null) :: + Row(4, null) :: + Row(5, null) :: + Row(6, null) :: Nil) val corr7 = spark.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) @@ -869,7 +869,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // one row test val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") - checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(Double.NaN)) + checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(null)) checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index 15712a18ce751..6bf7bd6cbb90e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -62,7 +62,6 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto // Moved because: // - Spark uses a different default stddev (sample instead of pop) // - Tiny numerical differences in stddev results. 
- // - Different StdDev behavior when n=1 (NaN instead of 0) checkAnswer(sql(s""" |select p_mfgr,p_name, p_size, |rank() over(distribute by p_mfgr sort by p_name) as r, @@ -88,22 +87,22 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), - Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, null, 4, 14, 14), Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), - Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0, null, 2, 17, 17), Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), - Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, null, 0, 10, 10), Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), - Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, null, 1, 31, 31), Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 
6),
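To close out, a hedged sketch (not taken from the patch itself) of how the new `nullOnDivideByZero` flag added above is fixed at expression construction time, either passed explicitly or derived from `SQLConf` via the auxiliary single-argument constructor that keeps existing call sites working; the value names are illustrative:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp

// Explicit flag: with a single input row (n == 1), the first expression
// evaluates to NULL and the second to Double.NaN.
val nullOnDbz = StddevSamp(Literal(1.0), nullOnDivideByZero = true)
val nanOnDbz = StddevSamp(Literal(1.0), nullOnDivideByZero = false)

// Auxiliary constructor: picks up !SQLConf.get.legacyStatisticalAggregate,
// so spark.sql.legacy.statisticalAggregate=true restores the NaN result.
val fromConf = new StddevSamp(Literal(1.0))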