diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 500c040dfe4e..2d0792208129 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -27,7 +27,8 @@ import org.apache.commons.text.StringEscapeUtils
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.BinaryLike
@@ -37,7 +38,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-
 abstract class StringRegexExpression extends BinaryExpression
   with ImplicitCastInputTypes with NullIntolerant with Predicate {
 
@@ -594,14 +594,28 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
       return defaultCheck
     }
     if (!pos.foldable) {
-      return TypeCheckFailure(s"Position expression must be foldable, but got $pos")
+      return DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "position",
+          "inputType" -> toSQLType(pos.dataType),
+          "inputExpr" -> toSQLExpr(pos)
+        )
+      )
     }
 
     val posEval = pos.eval()
     if (posEval == null || posEval.asInstanceOf[Int] > 0) {
       TypeCheckSuccess
     } else {
-      TypeCheckFailure(s"Position expression must be positive, but got: $posEval")
+      DataTypeMismatch(
+        errorSubClass = "VALUE_OUT_OF_RANGE",
+        messageParameters = Map(
+          "exprName" -> "position",
+          "valueRange" -> s"(0, ${Int.MaxValue}]",
+          "currentValue" -> toSQLValue(posEval, pos.dataType)
+        )
+      )
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 1bc79f238464..6927c4cfa3c9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -25,6 +25,8 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
@@ -273,18 +275,35 @@ case class Elt(
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size < 2) {
-      TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_PARAMS",
+        messageParameters = Map(
+          "functionName" -> "elt",
+          "expectedNum" -> "> 1",
+          "actualNum" -> children.length.toString
+        )
+      )
     } else {
       val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType))
       if (indexType != IntegerType) {
-        return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " +
-          s"have ${IntegerType.catalogString}, but it's ${indexType.catalogString}")
+        return DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "1",
+            "requiredType" -> toSQLType(IntegerType),
+            "inputSql" -> toSQLExpr(indexExpr),
+            "inputType" -> toSQLType(indexType)))
       }
       if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
-        return TypeCheckResult.TypeCheckFailure(
-          s"input to function $prettyName should have ${StringType.catalogString} or " +
-          s"${BinaryType.catalogString}, but it's " +
-          inputTypes.map(_.catalogString).mkString("[", ", ", "]"))
+        return DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "2...",
+            "requiredType" -> (toSQLType(StringType) + " or " + toSQLType(BinaryType)),
+            "inputSql" -> inputExprs.map(toSQLExpr(_)).mkString(","),
+            "inputType" -> inputTypes.map(toSQLType(_)).mkString(",")
+          )
+        )
       }
       TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName")
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index 9089963ee852..98a6a9bc19c4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{IntegerType, StringType}
 
 /**
  * Unit tests for regular expression (regexp) related SQL expressions.
@@ -531,4 +533,23 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       create_row("abc", ", (", 0),
       s"$prefix `regexp_instr` is invalid: , (")
   }
+
+  test("RegExpReplace: fails analysis if pos is not a constant") {
+    val s = $"s".string.at(0)
+    val p = $"p".string.at(1)
+    val r = $"r".string.at(2)
+    val posExpr = AttributeReference("b", IntegerType)()
+    val expr = RegExpReplace(s, p, r, posExpr)
+
+    assert(expr.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "position",
+          "inputType" -> toSQLType(posExpr.dataType),
+          "inputExpr" -> toSQLExpr(posExpr)
+        )
+      )
+    )
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 0d42155a5111..fce94bf02a0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -21,7 +21,9 @@ import java.math.{BigDecimal => JavaBigDecimal}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -1583,4 +1585,51 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Contains(Literal("Spark SQL"), Literal("SQL")), true)
     checkEvaluation(Contains(Literal("Spark SQL"), Literal("k S")), true)
   }
+
+  test("Elt: checkInputDataTypes") {
+    // requires at least two arguments
+    val indexExpr1 = Literal(8)
+    val expr1 = Elt(Seq(indexExpr1))
+    assert(expr1.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_PARAMS",
+        messageParameters = Map(
+          "functionName" -> "elt",
+          "expectedNum" -> "> 1",
+          "actualNum" -> "1"
+        )
+      )
+    )
+
+    // first input to function elt should have IntegerType
+    val indexExpr2 = Literal('a')
+    val expr2 = Elt(Seq(indexExpr2, Literal('b')))
+    assert(expr2.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(IntegerType),
+          "inputSql" -> toSQLExpr(indexExpr2),
+          "inputType" -> toSQLType(indexExpr2.dataType)
+        )
+      )
+    )
+
+    // input to function elt should have StringType or BinaryType
+    val indexExpr3 = Literal(1)
+    val inputExpr3 = Seq(Literal('a'), Literal('b'), Literal(12345))
+    val expr3 = Elt(indexExpr3 +: inputExpr3)
+    assert(expr3.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "2...",
+          "requiredType" -> (toSQLType(StringType) + " or " + toSQLType(BinaryType)),
+          "inputSql" -> inputExpr3.map(toSQLExpr(_)).mkString(","),
+          "inputType" -> inputExpr3.map(expr => toSQLType(expr.dataType)).mkString(",")
+        )
+      )
+    )
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out
index adeea49a3e37..60094af7a991 100644
--- a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out
@@ -355,7 +355,22 @@ SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', -2)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', -2)' due to data type mismatch: Position expression must be positive, but got: -2; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+  "messageParameters" : {
+    "currentValue" : "-2",
+    "exprName" : "position",
+    "sqlExpr" : "\"regexp_replace(healthy, wealthy, and wise, \\w+thy, something, -2)\"",
+    "valueRange" : "(0, 2147483647]"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 79,
+    "fragment" : "regexp_replace('healthy, wealthy, and wise', '\\\\w+thy', 'something', -2)"
+  } ]
+}
 
 
 -- !query
@@ -364,7 +379,22 @@ SELECT regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', 0)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'regexp_replace('healthy, wealthy, and wise', '\\w+thy', 'something', 0)' due to data type mismatch: Position expression must be positive, but got: 0; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+  "messageParameters" : {
+    "currentValue" : "0",
+    "exprName" : "position",
+    "sqlExpr" : "\"regexp_replace(healthy, wealthy, and wise, \\w+thy, something, 0)\"",
+    "valueRange" : "(0, 2147483647]"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 78,
+    "fragment" : "regexp_replace('healthy, wealthy, and wise', '\\\\w+thy', 'something', 0)"
+  } ]
+}
 
 
 -- !query