diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index bd449930e3b33..de50e78074ed9 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -138,6 +138,11 @@
           "Unable to convert column <name> of type <type> to JSON."
         ]
       },
+      "CANNOT_DROP_ALL_FIELDS" : {
+        "message" : [
+          "Cannot drop all fields in struct."
+        ]
+      },
       "CAST_WITHOUT_SUGGESTION" : {
         "message" : [
           "cannot cast <srcType> to <targetType>."
@@ -155,6 +160,21 @@
           "To convert values from <srcType> to <targetType>, you can use the functions <functionNames> instead."
         ]
       },
+      "CREATE_MAP_KEY_DIFF_TYPES" : {
+        "message" : [
+          "The given keys of function <functionName> should all be the same type, but they are <dataType>."
+        ]
+      },
+      "CREATE_MAP_VALUE_DIFF_TYPES" : {
+        "message" : [
+          "The given values of function <functionName> should all be the same type, but they are <dataType>."
+        ]
+      },
+      "CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING" : {
+        "message" : [
+          "Only foldable `STRING` expressions are allowed to appear at odd position, but they are <inputExprs>."
+        ]
+      },
       "DATA_DIFF_TYPES" : {
         "message" : [
           "Input to <functionName> should all be the same type, but it's <dataType>."
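
Context for the template strings above: the `<...>` placeholders are filled from the `messageParameters` maps built in the Scala changes below. A minimal sketch of surfacing one of the new classes end to end (assumes a local SparkSession; the commented output is illustrative, not verbatim):

    import org.apache.spark.sql.{AnalysisException, SparkSession}

    object CreateMapErrorDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
        try {
          // INT and BOOLEAN keys have no common type, so analysis fails with
          // DATATYPE_MISMATCH.CREATE_MAP_KEY_DIFF_TYPES using the template above.
          spark.sql("SELECT map(1, 'a', true, 'b')").collect()
        } catch {
          // e.g. ... The given keys of function `map` should all be the same
          // type, but they are ["INT", "BOOLEAN"].
          case e: AnalysisException => println(e.getMessage)
        } finally {
          spark.stop()
        }
      }
    }
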
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 27d4f506ac864..97c882fd176be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -202,16 +204,30 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size % 2 != 0) {
-      TypeCheckResult.TypeCheckFailure(
-        s"$prettyName expects a positive even number of arguments.")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "2n (n > 0)",
+          "actualNum" -> children.length.toString
+        )
+      )
     } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) {
-      TypeCheckResult.TypeCheckFailure(
-        "The given keys of function map should all be the same type, but they are " +
-          keys.map(_.dataType.catalogString).mkString("[", ", ", "]"))
+      DataTypeMismatch(
+        errorSubClass = "CREATE_MAP_KEY_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> keys.map(key => toSQLType(key.dataType)).mkString("[", ", ", "]")
+        )
+      )
     } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) {
-      TypeCheckResult.TypeCheckFailure(
-        "The given values of function map should all be the same type, but they are " +
-          values.map(_.dataType.catalogString).mkString("[", ", ", "]"))
+      DataTypeMismatch(
+        errorSubClass = "CREATE_MAP_VALUE_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "dataType" -> values.map(value => toSQLType(value.dataType)).mkString("[", ", ", "]")
+        )
+      )
     } else {
       TypeUtils.checkForMapKeyType(dataType.keyType)
     }
@@ -444,17 +460,32 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size % 2 != 0) {
-      TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "2n (n > 0)",
+          "actualNum" -> children.length.toString
+        )
+      )
     } else {
       val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
       if (invalidNames.nonEmpty) {
-        TypeCheckResult.TypeCheckFailure(
-          s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" +
-            s" position, got: ${invalidNames.mkString(",")}")
+        DataTypeMismatch(
+          errorSubClass = "CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING",
+          messageParameters = Map(
+            "inputExprs" -> invalidNames.map(toSQLExpr(_)).mkString("[", ", ", "]")
+          )
+        )
       } else if (!names.contains(null)) {
         TypeCheckResult.TypeCheckSuccess
       } else {
-        TypeCheckResult.TypeCheckFailure("Field name should not be null")
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_NULL",
+          messageParameters = Map(
+            "exprName" -> nameExprs.map(toSQLExpr).mkString("[", ", ", "]")
+          )
+        )
       }
     }
   }
@@ -668,10 +699,19 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat
   override def checkInputDataTypes(): TypeCheckResult = {
     val dataType = structExpr.dataType
     if (!dataType.isInstanceOf[StructType]) {
-      TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
-        dataType.catalogString)
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(StructType),
+          "inputSql" -> toSQLExpr(structExpr),
+          "inputType" -> toSQLType(structExpr.dataType))
+      )
     } else if (newExprs.isEmpty) {
-      TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
+      DataTypeMismatch(
+        errorSubClass = "CANNOT_DROP_ALL_FIELDS",
+        messageParameters = Map.empty
+      )
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
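
The net effect of the changes above is that `checkInputDataTypes` now returns structured data (a subclass name plus parameters) instead of a pre-rendered string. A small sketch of consuming the result directly, in the spirit of the ComplexTypeSuite assertions below (assumes the Catalyst test classpath; names are illustrative):

    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
    import org.apache.spark.sql.catalyst.expressions.{CreateMap, Literal}

    // INT vs. BOOLEAN keys trip the new CREATE_MAP_KEY_DIFF_TYPES subclass;
    // the parameters stay machine-readable instead of being baked into a message.
    val badMap = CreateMap(Seq(Literal(1), Literal("a"), Literal(true), Literal("b")))
    badMap.checkInputDataTypes() match {
      case DataTypeMismatch(subClass, params) => println(s"$subClass: $params")
      case TypeCheckResult.TypeCheckSuccess => println("ok")
      case other => println(other)
    }
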
appear at odd position") - assertError( - CreateNamedStruct(Seq($"a".string.at(0), "a", "b", 2.0)), - "Only foldable string expressions are allowed to appear at odd position") - assertError( - CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), - "Field name should not be null") + checkError( + exception = analysisException(CreateNamedStruct(Seq("a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS", + parameters = Map( + "sqlExpr" -> "\"named_struct(a, b, 2.0)\"", + "functionName" -> "`named_struct`", + "expectedNum" -> "2n (n > 0)", + "actualNum" -> "3") + ) + checkError( + exception = analysisException(CreateNamedStruct(Seq(1, "a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + parameters = Map( + "sqlExpr" -> "\"named_struct(1, a, b, 2.0)\"", + "inputExprs" -> "[\"1\"]") + ) + checkError( + exception = analysisException(CreateNamedStruct(Seq($"a".string.at(0), "a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + parameters = Map( + "sqlExpr" -> "\"named_struct(boundreference(), a, b, 2.0)\"", + "inputExprs" -> "[\"boundreference()\"]") + ) + checkError( + exception = analysisException(CreateNamedStruct(Seq(Literal.create(null, StringType), "a"))), + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + parameters = Map( + "sqlExpr" -> "\"named_struct(NULL, a)\"", + "exprName" -> "[\"NULL\"]") + ) } test("check types for CreateMap") { - assertError(CreateMap(Seq("a", "b", 2.0)), "even number of arguments") - assertError( - CreateMap(Seq($"intField", $"stringField", - $"booleanField", $"stringField")), - "keys of function map should all be the same type") - assertError( - CreateMap(Seq($"stringField", $"intField", - $"stringField", $"booleanField")), - "values of function map should all be the same type") + checkError( + exception = analysisException(CreateMap(Seq("a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS", + parameters = Map( + "sqlExpr" -> "\"map(a, b, 2.0)\"", + "functionName" -> "`map`", + "expectedNum" -> "2n (n > 0)", + "actualNum" -> "3") + ) + checkError( + exception = analysisException(CreateMap(Seq(Literal(1), + Literal("a"), Literal(true), Literal("b")))), + errorClass = "DATATYPE_MISMATCH.CREATE_MAP_KEY_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"map(1, a, true, b)\"", + "functionName" -> "`map`", + "dataType" -> "[\"INT\", \"BOOLEAN\"]" + ) + ) + checkError( + exception = analysisException(CreateMap(Seq(Literal("a"), + Literal(1), Literal("b"), Literal(true)))), + errorClass = "DATATYPE_MISMATCH.CREATE_MAP_VALUE_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"map(a, 1, b, true)\"", + "functionName" -> "`map`", + "dataType" -> "[\"INT\", \"BOOLEAN\"]" + ) + ) } test("check types for ROUND/BROUND") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index fb6a23e3d776c..f1f781b7137b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} +import 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index fb6a23e3d776c..f1f781b7137b4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util._
@@ -314,6 +315,40 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       assert(errorSubClass == "INVALID_MAP_KEY_TYPE")
       assert(messageParameters === Map("keyType" -> "\"MAP\""))
     }
+
+    // expects a positive even number of arguments
+    val map3 = CreateMap(Seq(Literal(1), Literal(2), Literal(3)))
+    assert(map3.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> "`map`",
+          "expectedNum" -> "2n (n > 0)",
+          "actualNum" -> "3")
+      )
+    )
+
+    // The given keys of function map should all be the same type
+    val map4 = CreateMap(Seq(Literal(1), Literal(2), Literal('a'), Literal(3)))
+    assert(map4.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "CREATE_MAP_KEY_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> "`map`",
+          "dataType" -> "[\"INT\", \"STRING\"]")
+      )
+    )
+
+    // The given values of function map should all be the same type
+    val map5 = CreateMap(Seq(Literal(1), Literal(2), Literal(3), Literal('a')))
+    assert(map5.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "CREATE_MAP_VALUE_DIFF_TYPES",
+        messageParameters = Map(
+          "functionName" -> "`map`",
+          "dataType" -> "[\"INT\", \"STRING\"]")
+      )
+    )
   }
 
   test("MapFromArrays") {
@@ -397,6 +432,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       create_row(UTF8String.fromString("x"), 2.0))
     checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))),
       create_row(null))
+
+    // expects a positive even number of arguments
+    val namedStruct1 = CreateNamedStruct(Seq(Literal(1), Literal(2), Literal(3)))
+    assert(namedStruct1.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> "`named_struct`",
+          "expectedNum" -> "2n (n > 0)",
+          "actualNum" -> "3")
+      )
+    )
   }
 
   test("test dsl for complex type") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 554f6a34b17e8..3c9f3e58cec63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -940,7 +940,7 @@ class Column(val expr: Expression) extends Logging {
    *
    *   val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
    *   df.select($"struct_col".dropFields("a", "b"))
-   *   // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct
+   *   // result: org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS] Cannot resolve "update_fields(struct_col, dropfield(), dropfield())" due to data type mismatch: Cannot drop all fields in struct.;
    *
    *   val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
    *   df.select($"struct_col".dropFields("b"))
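
Since the scaladoc now shows an error-classed message, callers can branch on the error class programmatically rather than matching message substrings. A sketch (assumes a SparkSession named `spark`; `getErrorClass` comes from `SparkThrowable`, and the dotted name below mirrors the `errorClass` strings used in the tests that follow):

    import org.apache.spark.sql.AnalysisException
    import org.apache.spark.sql.functions.col

    try {
      spark.sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
        .select(col("struct_col").dropFields("a", "b"))
        .collect()
    } catch {
      case e: AnalysisException =>
        // Expected: "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS"
        println(e.getErrorClass)
    }
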
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index f109b7ff90481..32ea6f74a0757 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -1023,9 +1023,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         nullable = false))))
 
   test("withField should throw an exception if called on a non-StructType column") {
-    intercept[AnalysisException] {
-      testData.withColumn("key", $"key".withField("a", lit(2)))
-    }.getMessage should include("struct argument should be struct type, got: int")
+    checkError(
+      exception = intercept[AnalysisException] {
+        testData.withColumn("key", $"key".withField("a", lit(2)))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"update_fields(key, WithField(2))\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"key\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"STRUCT\"")
+    )
   }
 
   test("withField should throw an exception if either fieldName or col argument are null") {
@@ -1059,9 +1068,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("withField should throw an exception if intermediate field is not a struct") {
-    intercept[AnalysisException] {
-      structLevel1.withColumn("a", $"a".withField("b.a", lit(2)))
-    }.getMessage should include("struct argument should be struct type, got: int")
+    checkError(
+      exception = intercept[AnalysisException] {
+        structLevel1.withColumn("a", $"a".withField("b.a", lit(2)))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"update_fields(a.b, WithField(2))\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"a.b\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"STRUCT\"")
+    )
   }
 
   test("withField should throw an exception if intermediate field reference is ambiguous") {
@@ -1788,9 +1806,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("dropFields should throw an exception if called on a non-StructType column") {
-    intercept[AnalysisException] {
-      testData.withColumn("key", $"key".dropFields("a"))
-    }.getMessage should include("struct argument should be struct type, got: int")
+    checkError(
+      exception = intercept[AnalysisException] {
+        testData.withColumn("key", $"key".dropFields("a"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"update_fields(key, dropfield())\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"key\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"STRUCT\"")
+    )
   }
 
   test("dropFields should throw an exception if fieldName argument is null") {
@@ -1816,9 +1843,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("dropFields should throw an exception if intermediate field is not a struct") {
-    intercept[AnalysisException] {
-      structLevel1.withColumn("a", $"a".dropFields("b.a"))
-    }.getMessage should include("struct argument should be struct type, got: int")
+    checkError(
+      exception = intercept[AnalysisException] {
+        structLevel1.withColumn("a", $"a".dropFields("b.a"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"update_fields(a.b, dropfield())\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"a.b\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"STRUCT\"")
+    )
   }
 
   test("dropFields should throw an exception if intermediate field reference is ambiguous") {
@@ -1873,9 +1909,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("dropFields should throw an exception if no fields will be left in struct") {
-    intercept[AnalysisException] {
-      structLevel1.withColumn("a", $"a".dropFields("a", "b", "c"))
-    }.getMessage should include("cannot drop all fields in struct")
drop all fields in struct") + checkError( + exception = intercept[AnalysisException] { + structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) + }, + errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", + parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\"") + ) } test("dropFields should drop field with no name in struct") { @@ -2140,10 +2180,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"struct_col".dropFields("b", "c")), Row(Row(1))) - intercept[AnalysisException] { - sql("SELECT named_struct('a', 1, 'b', 2) struct_col") - .select($"struct_col".dropFields("a", "b")) - }.getMessage should include("cannot drop all fields in struct") + checkError( + exception = intercept[AnalysisException] { + sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + .select($"struct_col".dropFields("a", "b")) + }, + errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", + parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\"") + ) checkAnswer( sql("SELECT CAST(NULL AS struct) struct_col")