diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 4d35d37a414e..824024a84cba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@@ -68,13 +69,8 @@ case class SortOrder(
 
   override def children: Seq[Expression] = child +: sameOrderExpressions
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (RowOrdering.isOrderable(dataType)) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(s"cannot sort data type ${dataType.catalogString}")
-    }
-  }
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(dataType, prettyName)
 
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable
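
Note on the SortOrder hunk above: delegating to TypeUtils.checkForOrderingExpr means the failure now carries the DATATYPE_MISMATCH.INVALID_ORDERING_TYPE error class instead of a free-form "cannot sort data type" string. The rule it enforces (RowOrdering.isOrderable) is recursive: atomic types are orderable, arrays and structs are orderable exactly when their children are, and maps never are. A self-contained sketch of that rule over a toy type algebra (the names below are illustrative, not Spark's):

object OrderableSketch extends App {
  sealed trait DType
  case object AtomicT extends DType                       // stands in for INT, STRING, ...
  final case class ArrayT(elem: DType) extends DType
  final case class MapT(key: DType, value: DType) extends DType
  final case class StructT(fields: List[DType]) extends DType

  // Mirrors RowOrdering.isOrderable: maps are the one container with no defined order.
  def isOrderable(dt: DType): Boolean = dt match {
    case AtomicT         => true
    case ArrayT(elem)    => isOrderable(elem)
    case StructT(fields) => fields.forall(isOrderable)
    case MapT(_, _)      => false
  }

  assert(isOrderable(ArrayT(AtomicT)))
  assert(!isOrderable(MapT(AtomicT, AtomicT)))            // the case the new tests pin down
}
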
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala
index cb740672af34..d6b754a297d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDT.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{DataType, UserDefinedType}
 
-
 /**
  * Unwrap UDT data type column into its underlying type.
  */
@@ -33,8 +34,13 @@ case class UnwrapUDT(child: Expression) extends UnaryExpression with NonSQLExpre
     if (child.dataType.isInstanceOf[UserDefinedType[_]]) {
       TypeCheckResult.TypeCheckSuccess
     } else {
-      TypeCheckResult.TypeCheckFailure(
-        s"Input type should be UserDefinedType but got ${child.dataType.catalogString}")
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType("UserDefinedType"),
+          "inputSql" -> toSQLExpr(child),
+          "inputType" -> toSQLType(child.dataType)))
     }
   }
   override def dataType: DataType = child.dataType.asInstanceOf[UserDefinedType[_]].sqlType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 1e6cc356173e..5e22225db1ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COALESCE, NULL_CHECK, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
-
 /**
  * An expression that is evaluated to the first non-null input.
  *
@@ -57,8 +58,14 @@ case class Coalesce(children: Seq[Expression])
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length < 1) {
-      TypeCheckResult.TypeCheckFailure(
-        s"input to function $prettyName requires at least one argument")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "> 0",
+          "actualNum" -> children.length.toString
+        )
+      )
     } else {
       TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), prettyName)
     }
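
Both hunks above follow the same migration pattern: a TypeCheckFailure built from an interpolated string becomes a DataTypeMismatch carrying an error sub-class plus a map of named parameters, and the final message is rendered from a template looked up by that sub-class (in Spark the templates live in error-classes.json). A minimal, self-contained model of the mechanism, with template wording approximated rather than copied from Spark:

object ErrorClassSketch extends App {
  final case class DataTypeMismatch(
      errorSubClass: String,
      messageParameters: Map[String, String])

  // Stand-ins for the JSON templates; <name> placeholders match the parameter keys.
  private val templates = Map(
    "WRONG_NUM_ARGS" ->
      "The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>.",
    "UNEXPECTED_INPUT_TYPE" ->
      "Parameter <paramIndex> requires the <requiredType> type, however <inputSql> has the type <inputType>.")

  def render(e: DataTypeMismatch): String =
    e.messageParameters.foldLeft(templates(e.errorSubClass)) {
      case (msg, (key, value)) => msg.replace(s"<$key>", value)
    }

  println(render(DataTypeMismatch(
    "WRONG_NUM_ARGS",
    Map("functionName" -> "`coalesce`", "expectedNum" -> "> 0", "actualNum" -> "0"))))
}

The payoff is that tests (see the suites later in this patch) can assert on the structured fields instead of grepping substrings out of a message.
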
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
index de5fde27f6a8..ac31292f0328 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
@@ -22,6 +22,8 @@ import java.util.regex.Pattern
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.trees.UnaryLike
@@ -182,7 +184,12 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size > 3 || children.size < 2) {
-      TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "[2, 3]",
+          "actualNum" -> children.length.toString))
     } else {
       super[ExpectsInputTypes].checkInputDataTypes()
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index de9d1523bd3c..2ed13944be9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions.xml
 
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.types._
@@ -42,7 +43,14 @@ abstract class XPathExtract
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!path.foldable) {
-      TypeCheckFailure("path should be a string literal")
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "path",
+          "inputType" -> toSQLType(StringType),
+          "inputExpr" -> toSQLExpr(path)
+        )
+      )
     } else {
       super.checkInputDataTypes()
    }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
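
The parameter values in these hunks are normalized by small quoting helpers imported via Cast._: toSQLId wraps identifiers in backticks (`parse_url`), while toSQLType and toSQLExpr render types and expression SQL in double quotes ("STRING", "nonfoldableliteral()"). A rough sketch of the convention, with simplified signatures (the real helpers take DataType and Expression arguments, not strings):

object QuotingSketch extends App {
  def toSQLId(id: String): String = "`" + id + "`"
  def toSQLType(tpe: String): String = "\"" + tpe.toUpperCase + "\""
  def toSQLExpr(sql: String): String = "\"" + sql + "\""

  assert(toSQLId("parse_url") == "`parse_url`")
  assert(toSQLType("string") == "\"STRING\"")
  assert(toSQLExpr("2") == "\"2\"")
}
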
index b718f410be6b..d530be5f5e4b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -327,10 +327,14 @@ class AnalysisErrorSuite extends AnalysisTest {
     testRelation2.groupBy($"a")(sum(UnresolvedStar(None))),
     "Invalid usage of '*' in expression 'sum'." :: Nil)
 
-  errorTest(
+  errorClassTest(
     "sorting by unsupported column types",
     mapRelation.orderBy($"map".asc),
-    "sort" :: "type" :: "map" :: Nil)
+    errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
+    messageParameters = Map(
+      "sqlExpr" -> "\"map ASC NULLS FIRST\"",
+      "functionName" -> "`sortorder`",
+      "dataType" -> "\"MAP\""))
 
   errorClassTest(
     "sorting by attributes are not from grouping expressions",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a7cdd589606c..9bc765df75e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -440,7 +440,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
       )
     )
 
-    assertError(Coalesce(Nil), "function coalesce requires at least one argument")
+    val coalesce = Coalesce(Nil)
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(coalesce)
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+      parameters = Map(
+        "sqlExpr" -> "\"coalesce()\"",
+        "functionName" -> toSQLId(coalesce.prettyName),
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"))
 
     val murmur3Hash = new Murmur3Hash(Nil)
     checkError(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
index 4e6976f76ea5..9332ef559532 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.Timestamp
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@@ -83,4 +84,15 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L)
     checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null)
   }
+
+  test("Cannot sort map type") {
+    val m = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false))
+    val sortOrderExpression = SortOrder(m, Ascending)
+    assert(sortOrderExpression.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "INVALID_ORDERING_TYPE",
+        messageParameters = Map(
+          "functionName" -> "`sortorder`",
+          "dataType" -> "\"MAP<STRING, STRING>\"")))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index f9726c4a6dd5..804d9351c7f9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -1497,12 +1497,43 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
 
     // arguments checking
-    assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
-    assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4")))
-      .checkInputDataTypes().isFailure)
-    assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
-    assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
-    assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)
+    assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes() == DataTypeMismatch(
+      errorSubClass = "WRONG_NUM_ARGS",
+      messageParameters = Map(
+        "functionName" -> "`parse_url`",
+        "expectedNum" -> "[2, 3]",
+        "actualNum" -> "1")
+    ))
+    assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"),
+      Literal("4"))).checkInputDataTypes() == DataTypeMismatch(
+      errorSubClass = "WRONG_NUM_ARGS",
+      messageParameters = Map(
+        "functionName" -> "`parse_url`",
+        "expectedNum" -> "[2, 3]",
+        "actualNum" -> "4")
+    ))
+    assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes() == DataTypeMismatch(
+      errorSubClass = "UNEXPECTED_INPUT_TYPE",
+      messageParameters = Map(
+        "paramIndex" -> "2",
+        "requiredType" -> "\"STRING\"",
+        "inputSql" -> "\"2\"",
+        "inputType" -> "\"INT\"")))
+    assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes() == DataTypeMismatch(
+      errorSubClass = "UNEXPECTED_INPUT_TYPE",
+      messageParameters = Map(
+        "paramIndex" -> "1",
+        "requiredType" -> "\"STRING\"",
+        "inputSql" -> "\"1\"",
+        "inputType" -> "\"INT\"")))
+    assert(ParseUrl(Seq(Literal("1"), Literal("2"),
+      Literal(3))).checkInputDataTypes() == DataTypeMismatch(
+      errorSubClass = "UNEXPECTED_INPUT_TYPE",
+      messageParameters = Map(
+        "paramIndex" -> "3",
+        "requiredType" -> "\"STRING\"",
+        "inputSql" -> "\"3\"",
+        "inputType" -> "\"INT\"")))
 
     // Test escaping of arguments
     GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDTExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDTExpressionSuite.scala
new file mode 100644
index 000000000000..d1b13a4bec99
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnwrapUDTExpressionSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
+import org.apache.spark.sql.types.BooleanType
+
+class UnwrapUDTExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("Input type should be UserDefinedType") {
+    val b1 = Literal.create(false, BooleanType)
+    val unwrapUDTExpression = UnwrapUDT(b1)
+    assert(unwrapUDTExpression.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType("UserDefinedType"),
+          "inputSql" -> "\"false\"",
+          "inputType" -> "\"BOOLEAN\"")))
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
index c6f6d3abb860..8d9f90a1a87c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.xml
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.StringType
 
@@ -195,7 +196,13 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     // Validate that non-foldable paths are not supported.
     val nonLitPath = exprCtor(Literal("abcd"), NonFoldableLiteral("/"))
-    assert(nonLitPath.checkInputDataTypes().isFailure)
+    assert(nonLitPath.checkInputDataTypes() == DataTypeMismatch(
+      errorSubClass = "NON_FOLDABLE_INPUT",
+      messageParameters = Map(
+        "inputName" -> "path",
+        "inputType" -> "\"STRING\"",
+        "inputExpr" -> "\"nonfoldableliteral()\"")
+    ))
   }
 
   testExpr(XPathBoolean)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 80e8f6d73420..71f576884d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -109,66 +109,21 @@ object StatFunctions extends Logging {
 
   /** Calculate the Pearson Correlation Coefficient for the given columns */
   def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
-    val counts = collectStatisticalData(df, cols, "correlation")
-    counts.Ck / math.sqrt(counts.MkX * counts.MkY)
-  }
-
-  /** Helper class to simplify tracking and merging counts. */
-  private class CovarianceCounter extends Serializable {
-    var xAvg = 0.0 // the mean of all examples seen so far in col1
-    var yAvg = 0.0 // the mean of all examples seen so far in col2
-    var Ck = 0.0 // the co-moment after k examples
-    var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
-    var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
-    var count = 0L // count of observed examples
-    // add an example to the calculation
-    def add(x: Double, y: Double): this.type = {
-      val deltaX = x - xAvg
-      val deltaY = y - yAvg
-      count += 1
-      xAvg += deltaX / count
-      yAvg += deltaY / count
-      Ck += deltaX * (y - yAvg)
-      MkX += deltaX * (x - xAvg)
-      MkY += deltaY * (y - yAvg)
-      this
-    }
-    // merge counters from other partitions. Formula can be found at:
-    // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-    def merge(other: CovarianceCounter): this.type = {
-      if (other.count > 0) {
-        val totalCount = count + other.count
-        val deltaX = xAvg - other.xAvg
-        val deltaY = yAvg - other.yAvg
-        Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
-        xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
-        yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
-        MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
-        MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
-        count = totalCount
-      }
-      this
+    require(cols.length == 2,
+      "Currently correlation calculation is supported between two columns.")
+    val Seq(col1, col2) = cols.map { c =>
+      val dataType = df.resolve(c).dataType
+      require(dataType.isInstanceOf[NumericType],
+        "Currently correlation calculation for columns with dataType " +
+          s"${dataType.catalogString} not supported.")
+      when(isnull(col(c)), lit(0.0))
+        .otherwise(col(c).cast(DoubleType))
     }
-    // return the sample covariance for the observed examples
-    def cov: Double = Ck / (count - 1)
-  }
-
-  private def collectStatisticalData(df: DataFrame, cols: Seq[String],
-      functionName: String): CovarianceCounter = {
-    require(cols.length == 2, s"Currently $functionName calculation is supported " +
-      "between two columns.")
-    cols.map(name => (name, df.resolve(name))).foreach { case (name, data) =>
-      require(data.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " +
-        s"for columns with dataType ${data.dataType.catalogString} not supported.")
-    }
-    val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
-    df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)(
-      seqOp = (counter, row) => {
-        counter.add(row.getDouble(0), row.getDouble(1))
-      },
-      combOp = (baseCounter, other) => {
-        baseCounter.merge(other)
-      })
+    val correlation = corr(col1, col2)
+    df.select(
+      when(isnull(correlation), lit(Double.NaN))
+        .otherwise(correlation)
+    ).head.getDouble(0)
   }
 
   /**
@@ -178,8 +133,21 @@
    * @return the covariance of the two columns.
    */
   def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
-    val counts = collectStatisticalData(df, cols, "covariance")
-    counts.cov
+    require(cols.length == 2,
+      "Currently covariance calculation is supported between two columns.")
+    val Seq(col1, col2) = cols.map { c =>
+      val dataType = df.resolve(c).dataType
+      require(dataType.isInstanceOf[NumericType],
+        "Currently covariance calculation for columns with dataType " +
+          s"${dataType.catalogString} not supported.")
+      when(isnull(col(c)), lit(0.0))
+        .otherwise(col(c).cast(DoubleType))
+    }
+    val covariance = covar_samp(col1, col2)
+    df.select(
+      when(isnull(covariance), lit(0.0))
+        .otherwise(covariance)
+    ).head.getDouble(0)
   }
 
   /** Generate a table of frequencies for the elements of two columns. */
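
The deleted CovarianceCounter was a single-pass, mergeable co-moment accumulator; its merge() uses the standard parallel-variance identity cited in the removed comment. Since the patch replaces it with the covar_samp and corr aggregates, a compact standalone version with a consistency check may help confirm the two approaches agree. This is illustrative only, not part of the patch:

object CovMergeSketch extends App {
  final class Counter {
    var xAvg, yAvg, ck = 0.0
    var n = 0L
    def add(x: Double, y: Double): this.type = {
      n += 1
      val dx = x - xAvg
      xAvg += dx / n
      yAvg += (y - yAvg) / n
      ck += dx * (y - yAvg)              // co-moment update, as in the removed class
      this
    }
    def merge(o: Counter): this.type = { // parallel combine, as in the removed merge()
      if (o.n > 0) {
        val total = n + o.n
        ck += o.ck + (xAvg - o.xAvg) * (yAvg - o.yAvg) * n / total * o.n
        xAvg = (xAvg * n + o.xAvg * o.n) / total
        yAvg = (yAvg * n + o.yAvg * o.n) / total
        n = total
      }
      this
    }
    def cov: Double = ck / (n - 1)       // sample covariance, like covar_samp
  }

  // Same data shape as the new SPARK-40933 test: ids 0..9, value = id when id % 3 == 0, else 0.0.
  val rows = (0 until 10).map(i => (i.toDouble, if (i % 3 == 0) i.toDouble else 0.0))
  val whole = rows.foldLeft(new Counter) { case (c, (x, y)) => c.add(x, y) }
  val (a, b) = rows.splitAt(4)
  val partA = a.foldLeft(new Counter) { case (c, (x, y)) => c.add(x, y) }
  val partB = b.foldLeft(new Counter) { case (c, (x, y)) => c.add(x, y) }
  assert(math.abs(partA.merge(partB).cov - whole.cov) < 1e-12)
  assert(math.abs(whole.cov - 5.0) < 1e-12)  // the value the new test asserts
}
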
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 3adf17518182..fe843c236c38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -4270,13 +4270,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
   test("SPARK-21281 fails if functions have no argument") {
     val df = Seq(1).toDF("a")
-    val funcsMustHaveAtLeastOneArg =
-      ("coalesce", (df: DataFrame) => df.select(coalesce())) ::
-      ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: Nil
-    funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
-      val errMsg = intercept[AnalysisException] { func(df) }.getMessage
-      assert(errMsg.contains(s"input to function $name requires at least one argument"))
-    }
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(coalesce())
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"coalesce()\"",
+        "functionName" -> "`coalesce`",
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("coalesce()")
+      },
+      errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS",
+      sqlState = None,
+      parameters = Map(
+        "sqlExpr" -> "\"coalesce()\"",
+        "functionName" -> "`coalesce`",
+        "expectedNum" -> "> 0",
+        "actualNum" -> "0"),
+      context = ExpectedContext(
+        fragment = "coalesce()",
+        start = 0,
+        stop = 9))
 
     checkError(
       exception = intercept[AnalysisException] {
"value") - 5.0) < 1e-12) + assert(math.abs(df1.stat.corr("id", "value") - 0.5120915564991891) < 1e-12) + + // empty dataframe + val df2 = df1.where(col("id") < 0) + assert(df2.stat.cov("id", "value") === 0) + assert(df2.stat.corr("id", "value").isNaN) + } + test("covariance") { val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")