diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 969128838eba4..c95570eac09ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -498,22 +499,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } + private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow + /** * Change the precision / scale in a given decimal to those set in `decimalType` (if any), - * returning null if it overflows or modifying `value` in-place and returning it if successful. + * modifying `value` in-place and returning it if successful. If an overflow occurs, it + * either returns null or throws an exception according to the value set for + * `spark.sql.decimalOperations.nullOnOverflow`. * * NOTE: this modifies `value` in-place, so don't call it on external data. 
*/ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { - if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null + if (value.changePrecision(decimalType.precision, decimalType.scale)) { + value + } else { + if (nullOnOverflow) { + null + } else { + throw new ArithmeticException(s"${value.toDebugString} cannot be represented as " + + s"Decimal(${decimalType.precision}, ${decimalType.scale}).") + } + } } /** - * Create new `Decimal` with precision and scale given in `decimalType` (if any), - * returning null if it overflows or creating a new `value` and returning it if successful. + * Create new `Decimal` with precision and scale given in `decimalType` (if any). + * On overflow, null is returned if `spark.sql.decimalOperations.nullOnOverflow` is true; + * otherwise, an `ArithmeticException` is thrown. */ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = - value.toPrecision(decimalType.precision, decimalType.scale) + value.toPrecision( + decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow) private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { @@ -963,11 +979,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $d; """.stripMargin } else { + val overflowCode = if (nullOnOverflow) { + s"$evNull = true;" + } else { + s""" + |throw new ArithmeticException($d.toDebugString() + " cannot be represented as " + + | "Decimal(${decimalType.precision}, ${decimalType.scale})."); + """.stripMargin + } code""" |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { | $evPrim = $d; |} else { - | $evNull = true; + | $overflowCode |} """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 
4d667fd61ae01..d090657b50bb8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1018,4 +1019,25 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ret, InternalRow(null)) } } + + test("SPARK-28470: Cast should honor nullOnOverflow property") { + withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") { + checkEvaluation(Cast(Literal("134.12"), DecimalType(3, 2)), null) + checkEvaluation( + Cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(134.12), DecimalType(3, 2)), null) + } + withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") { + checkExceptionInExpression[ArithmeticException]( + Cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented") + checkExceptionInExpression[ArithmeticException]( + Cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)), + "cannot be represented") + checkExceptionInExpression[ArithmeticException]( + Cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented") + checkExceptionInExpression[ArithmeticException]( + Cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented") + } + } }