From 2793ac6cf38f32ce8a7b12a6edcc23a65ff3cff3 Mon Sep 17 00:00:00 2001 From: 07ARB Date: Fri, 27 Dec 2019 15:52:47 +0530 Subject: [PATCH] [SPARK-29854]lpad and rpad built in function should show Error or throw Exception for invalid length value --- .../expressions/stringExpressions.scala | 27 +++++++++++++++++++ .../expressions/StringExpressionsSuite.scala | 11 ++++++++ 2 files changed, 38 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 211ae3f02a0d8..091a7c4dd5f16 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.commons.codec.binary.{Base64 => CommonsBase64} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -1227,6 +1228,19 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + override def checkInputDataTypes(): TypeCheckResult = { + val inputTypeCheck = super.checkInputDataTypes() + if (inputTypeCheck.isFailure) { + try { + if (len != null && len.toString.toInt.isValidInt) inputTypeCheck + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"Invalid argument, $inputTypeCheck") + } + } + inputTypeCheck + } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } @@ -1268,6 +1282,19 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) + override def checkInputDataTypes(): TypeCheckResult = { + val inputTypeCheck = super.checkInputDataTypes() + if (inputTypeCheck.isFailure) { + try { + if (len != null && len.toString.toInt.isValidInt) inputTypeCheck + } catch { + case _: NumberFormatException => + throw new AnalysisException(s"Invalid argument, $inputTypeCheck") + } + } + inputTypeCheck + } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4308f98d6969a..b5a74dbad8ab7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolderSparkSubmitSuite.{assert, intercept} import org.apache.spark.sql.types._ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -720,6 +721,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringRPad(s1, s2, s3), null, row5) checkEvaluation(StringRPad(Literal("hi"), Literal(5)), "hi ") checkEvaluation(StringRPad(Literal("hi"), Literal(1)), "h") + + assert(intercept[AnalysisException] { + checkEvaluation(StringRPad(Literal("hi"), Literal("invalidLength")), "Exception") + }.getMessage.contains("Invalid argument, TypeCheckFailure(argument 2 " + + "requires int type, however, ''invalidLength'' is of string type.);")) + + assert(intercept[AnalysisException] { + checkEvaluation(StringLPad(Literal("hi"), Literal("invalidLength")), "Exception") + }.getMessage.contains("Invalid argument, TypeCheckFailure(argument 2 " + + "requires int type, however, ''invalidLength'' is of string type.);")) } test("REPEAT") {