diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 8c0f1659ea50..1d03cb114c9b 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -269,6 +269,11 @@ "Input to the function cannot contain elements of the \"MAP\" type. In Spark, same maps may have different hashcode, thus hash expressions are prohibited on \"MAP\" elements. To restore previous behavior set \"spark.sql.legacy.allowHashOnMapType\" to \"true\"." ] }, + "INVALID_ARG_VALUE" : { + "message" : [ + "The value must to be a literal of , but got ." + ] + }, "INVALID_JSON_MAP_KEY_TYPE" : { "message" : [ "Input schema can only contain STRING as a key type for a MAP." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 60b56f4fef79..3a1db2ce1b8b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2620,39 +2620,30 @@ case class ToBinary( nullOnInvalidFormat: Boolean = false) extends RuntimeReplaceable with ImplicitCastInputTypes { - override lazy val replacement: Expression = format.map { f => - assert(f.foldable && (f.dataType == StringType || f.dataType == NullType)) + @transient lazy val fmt: String = format.map { f => val value = f.eval() if (value == null) { - Literal(null, BinaryType) + null } else { - value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) match { - case "hex" => Unhex(expr, failOnError = true) - case "utf-8" | "utf8" => Encode(expr, Literal("UTF-8")) - case "base64" => UnBase64(expr, failOnError = true) - case _ if nullOnInvalidFormat => Literal(null, BinaryType) - case other => throw QueryCompilationErrors.invalidStringLiteralParameter( - "to_binary", - "format", - other, - Some( - "The value has to be a case-insensitive string literal of " + - "'hex', 'utf-8', 'utf8', or 'base64'.")) - } + value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) + } + }.getOrElse("hex") + + override lazy val replacement: Expression = if (fmt == null) { + Literal(null, BinaryType) + } else { + fmt match { + case "hex" => Unhex(expr, failOnError = true) + case "utf-8" | "utf8" => Encode(expr, Literal("UTF-8")) + case "base64" => UnBase64(expr, failOnError = true) + case _ => Literal(null, BinaryType) } - }.getOrElse(Unhex(expr, failOnError = true)) + } def this(expr: Expression) = this(expr, None, false) def this(expr: Expression, format: Expression) = - this(expr, Some({ - // We perform this check in the constructor to make it eager and not go through type coercion. - if (format.foldable && (format.dataType == StringType || format.dataType == NullType)) { - format - } else { - throw QueryCompilationErrors.requireLiteralParameter("to_binary", "format", "string") - } - }), false) + this(expr, Some(format), false) override def prettyName: String = "to_binary" @@ -2660,6 +2651,50 @@ case class ToBinary( override def inputTypes: Seq[AbstractDataType] = children.map(_ => StringType) + override def checkInputDataTypes(): TypeCheckResult = { + def isValidFormat: Boolean = { + fmt == null || Set("hex", "utf-8", "utf8", "base64").contains(fmt) + } + format match { + case Some(f) => + if (f.foldable && (f.dataType == StringType || f.dataType == NullType)) { + if (isValidFormat || nullOnInvalidFormat) { + super.checkInputDataTypes() + } else { + DataTypeMismatch( + errorSubClass = "INVALID_ARG_VALUE", + messageParameters = Map( + "inputName" -> "fmt", + "requireType" -> s"case-insensitive ${toSQLType(StringType)}", + "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", + "inputValue" -> toSQLValue(fmt, StringType) + ) + ) + } + } else if (!f.foldable) { + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "fmt", + "inputType" -> toSQLType(StringType), + "inputExpr" -> toSQLExpr(f) + ) + ) + } else { + DataTypeMismatch( + errorSubClass = "INVALID_ARG_VALUE", + messageParameters = Map( + "inputName" -> "fmt", + "requireType" -> s"case-insensitive ${toSQLType(StringType)}", + "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", + "inputValue" -> toSQLValue(f.eval(), f.dataType) + ) + ) + } + case _ => super.checkInputDataTypes() + } + } + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): Expression = { if (format.isDefined) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 8bdbcb26e83c..42b1b967fe7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1256,6 +1256,21 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("ToBinary: fails analysis if fmt is not foldable") { + val wrongFmt = AttributeReference("invalidFormat", StringType)() + val toBinaryExpr = ToBinary(Literal("abc"), Some(wrongFmt)) + assert(toBinaryExpr.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "fmt", + "inputType" -> toSQLType(wrongFmt.dataType), + "inputExpr" -> toSQLExpr(wrongFmt) + ) + ) + ) + } + test("ToNumber: negative tests (the input string does not match the format string)") { Seq( // The input contained more thousands separators than the format string. diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index cb18c547b612..39c57e6efa28 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -225,3 +225,7 @@ select to_binary(null, cast(null as string)); -- invalid format select to_binary('abc', 1); select to_binary('abc', 'invalidFormat'); +CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat'); +SELECT to_binary('abc', fmtField) FROM fmtTable; +-- Clean up +DROP VIEW IF EXISTS fmtTable; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 3ab49c14bef1..5a0479996b95 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1610,11 +1610,13 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "funcName" : "to_binary", - "requiredType" : "string" + "inputName" : "fmt", + "inputValue" : "'1'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, 1)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" }, "queryContext" : [ { "objectType" : "", @@ -1633,11 +1635,59 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1101", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "endingMsg" : " The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'.", - "funcName" : "to_binary", - "invalidValue" : "invalidformat" - } + "inputName" : "fmt", + "inputValue" : "'invalidformat'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, invalidFormat)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 40, + "fragment" : "to_binary('abc', 'invalidFormat')" + } ] } + + +-- !query +CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat') +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT to_binary('abc', fmtField) FROM fmtTable +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "messageParameters" : { + "inputExpr" : "\"fmtField\"", + "inputName" : "fmt", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"to_binary(abc, fmtField)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 33, + "fragment" : "to_binary('abc', fmtField)" + } ] +} + + +-- !query +DROP VIEW IF EXISTS fmtTable +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2ea5cefa38d1..36814275cd7d 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1542,11 +1542,13 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "funcName" : "to_binary", - "requiredType" : "string" + "inputName" : "fmt", + "inputValue" : "'1'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, 1)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" }, "queryContext" : [ { "objectType" : "", @@ -1565,11 +1567,59 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1101", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "endingMsg" : " The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'.", - "funcName" : "to_binary", - "invalidValue" : "invalidformat" - } + "inputName" : "fmt", + "inputValue" : "'invalidformat'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, invalidFormat)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 40, + "fragment" : "to_binary('abc', 'invalidFormat')" + } ] } + + +-- !query +CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat') +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT to_binary('abc', fmtField) FROM fmtTable +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "messageParameters" : { + "inputExpr" : "\"fmtField\"", + "inputName" : "fmt", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"to_binary(abc, fmtField)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 33, + "fragment" : "to_binary('abc', fmtField)" + } ] +} + + +-- !query +DROP VIEW IF EXISTS fmtTable +-- !query schema +struct<> +-- !query output +