diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 82692334544e..24cb6b301dfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -82,6 +82,8 @@ object DecimalPrecision extends TypeCoercionRule {
     PromotePrecision(Cast(e, dataType))
   }
 
+  private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow
+
   override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // fix decimal precision for expressions
     case q => q.transformExpressionsUp(
@@ -105,7 +107,7 @@ object DecimalPrecision extends TypeCoercionRule {
           DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
         }
         CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
-          resultType)
+          resultType, nullOnOverflow)
 
       case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
         val resultScale = max(s1, s2)
@@ -116,7 +118,7 @@ object DecimalPrecision extends TypeCoercionRule {
           DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
         }
         CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
-          resultType)
+          resultType, nullOnOverflow)
 
       case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
         val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -126,7 +128,7 @@ object DecimalPrecision extends TypeCoercionRule {
         }
         val widerType = widerDecimalType(p1, s1, p2, s2)
         CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-          resultType)
+          resultType, nullOnOverflow)
 
       case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
         val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -148,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
         }
         val widerType = widerDecimalType(p1, s1, p2, s2)
         CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-          resultType)
+          resultType, nullOnOverflow)
 
       case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
         val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -159,7 +161,7 @@ object DecimalPrecision extends TypeCoercionRule {
         // resultType may have lower precision, so we cast them into wider type first.
         val widerType = widerDecimalType(p1, s1, p2, s2)
         CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-          resultType)
+          resultType, nullOnOverflow)
 
       case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
         val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
@@ -170,7 +172,7 @@ object DecimalPrecision extends TypeCoercionRule {
         // resultType may have lower precision, so we cast them into wider type first.
         val widerType = widerDecimalType(p1, s1, p2, s2)
         CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
-          resultType)
+          resultType, nullOnOverflow)
 
       case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
           e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
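The rule above fixes the result type of each decimal operator and wraps the operation in CheckOverflow, now threading the new nullOnOverflow flag through. As a minimal sketch outside Spark (MAX_PRECISION and the bounded cap are assumptions taken from DecimalType), this is the precision/scale arithmetic the Add/Subtract branches perform:

    // Sketch of the Add/Subtract result type above; DecimalType.bounded caps
    // precision and scale at MAX_PRECISION (38), which is what makes an
    // overflow possible in the first place.
    val MAX_PRECISION = 38

    def addResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
      val resultScale = math.max(s1, s2)
      val resultPrecision = math.max(p1 - s1, p2 - s2) + resultScale + 1
      (math.min(resultPrecision, MAX_PRECISION), math.min(resultScale, MAX_PRECISION))
    }

    // DECIMAL(38, 1) + DECIMAL(38, 1) would need precision 39 and is capped to
    // (38, 1): a sum such as 5e36 + 5e36 then no longer fits, and CheckOverflow
    // decides between NULL and ArithmeticException based on nullOnOverflow.
    assert(addResultType(38, 1, 38, 1) == (38, 1))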
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
index 76733dd6dac3..c1d72f9b58a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
           collect(left, negate) ++ collect(right, !negate)
         case UnaryMinus(child) =>
           collect(child, !negate)
-        case CheckOverflow(child, _) =>
+        case CheckOverflow(child, _, _) =>
           collect(child, negate)
         case PromotePrecision(child) =>
           collect(child, negate)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3a06f8d0c50f..afe8a23f8f15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -114,7 +114,7 @@ object RowEncoder {
           d,
           "fromDecimal",
           inputObject :: Nil,
-          returnNullable = false), d)
+          returnNullable = false), d, SQLConf.get.decimalOperationsNullOnOverflow)
 
     case StringType => createSerializerForString(inputObject)
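The StreamingJoinHelper change only adds a wildcard because adding a field to a case class changes the arity of its extractor, so every existing pattern match on it must grow by one position. A self-contained illustration with a toy class (not Spark's):

    // Adding a constructor field to a case class breaks every existing pattern
    // match on it, which is why the one-character `, _` fix above is required.
    case class CheckOverflow2(child: String, dataType: String, nullOnOverflow: Boolean)

    def stripOverflowCheck(e: Any): String = e match {
      case CheckOverflow2(child, _, _) => child // was CheckOverflow2(child, _) before
      case other => other.toString
    }

    assert(stripOverflowCheck(CheckOverflow2("a + b", "decimal(38,18)", true)) == "a + b")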
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 04de83343be7..ad7f7dd9434a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -81,30 +81,34 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
 
 /**
  * Rounds the decimal to given scale and check whether the decimal can fit in provided precision
- * or not, returns null if not.
+ * or not. If it cannot, `null` is returned when `nullOnOverflow` is `true`; otherwise an
+ * `ArithmeticException` is thrown.
  */
-case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
+case class CheckOverflow(
+    child: Expression,
+    dataType: DecimalType,
+    nullOnOverflow: Boolean) extends UnaryExpression {
 
   override def nullable: Boolean = true
 
   override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
+    input.asInstanceOf[Decimal].toPrecision(
+      dataType.precision,
+      dataType.scale,
+      Decimal.ROUND_HALF_UP,
+      nullOnOverflow)
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, eval => {
-      val tmp = ctx.freshName("tmp")
       s"""
-         | Decimal $tmp = $eval.clone();
-         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
-         |   ${ev.value} = $tmp;
-         | } else {
-         |   ${ev.isNull} = true;
-         | }
+         |${ev.value} = $eval.toPrecision(
+         |  ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
+         |${ev.isNull} = ${ev.value} == null;
       """.stripMargin
    })
  }
 
-  override def toString: String = s"CheckOverflow($child, $dataType)"
+  override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"
 
   override def sql: String = child.sql
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index c2e1720259b5..bdeb9ed29e0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1138,8 +1138,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
     val evaluationCode = dataType match {
       case DecimalType.Fixed(_, s) =>
         s"""
-        ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
-        ${ev.isNull} = ${ev.value} == null;"""
+           |${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
+           |  Decimal.$modeStr(), true);
+           |${ev.isNull} = ${ev.value} == null;
+         """.stripMargin
       case ByteType =>
         if (_scale < 0) {
           s"""
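With this change both the interpreted path (nullSafeEval) and the generated code delegate to Decimal.toPrecision. Mirroring the expression-level tests further down, a quick interpreted-evaluation sketch (Catalyst internals, so spark-catalyst on the classpath is assumed):

    import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, Literal}
    import org.apache.spark.sql.types.{Decimal, DecimalType}

    // 10.1 rounded to scale 3 is 10.100, which needs 5 digits and cannot fit
    // into DECIMAL(4, 3): the two modes diverge exactly here.
    val tooWide = Literal(Decimal("10.1"))
    assert(CheckOverflow(tooWide, DecimalType(4, 3), nullOnOverflow = true).eval() == null)
    try {
      CheckOverflow(tooWide, DecimalType(4, 3), nullOnOverflow = false).eval()
    } catch {
      case e: ArithmeticException => println(e.getMessage) // ... as Decimal(4, 3).
    }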
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d26cd2ca7343..21ffa07f83ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1441,6 +1441,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val DECIMAL_OPERATIONS_NULL_ON_OVERFLOW =
+    buildConf("spark.sql.decimalOperations.nullOnOverflow")
+      .internal()
+      .doc("When true (default), if an overflow on a decimal occurs, then NULL is returned. " +
+        "Spark's older versions and Hive behave in this way. If set to false, the SQL ANSI " +
+        "2011 specification is followed instead: an arithmetic exception is thrown, as most " +
+        "SQL databases do.")
+      .booleanConf
+      .createWithDefault(true)
+
   val LITERAL_PICK_MINIMUM_PRECISION =
     buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
       .internal()
@@ -2205,6 +2215,8 @@ class SQLConf extends Serializable with Logging {
 
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
 
+  def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)
+
   def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
 
   def continuousStreamingEpochBacklogQueueSize: Int =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 0192059a3a39..b7b70974f50e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -249,14 +249,25 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   /**
    * Create new `Decimal` with given precision and scale.
    *
-   * @return a non-null `Decimal` value if successful or `null` if overflow would occur.
+   * @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true,
+   *         null is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
    */
   private[sql] def toPrecision(
       precision: Int,
       scale: Int,
-      roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
+      roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
+      nullOnOverflow: Boolean = true): Decimal = {
     val copy = clone()
-    if (copy.changePrecision(precision, scale, roundMode)) copy else null
+    if (copy.changePrecision(precision, scale, roundMode)) {
+      copy
+    } else {
+      if (nullOnOverflow) {
+        null
+      } else {
+        throw new ArithmeticException(
+          s"$toDebugString cannot be represented as Decimal($precision, $scale).")
+      }
+    }
   }
 
   /**
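Since Decimal.toPrecision is private[sql], its semantics can be sketched with plain java.math.BigDecimal (an approximation for illustration, not Spark's exact changePrecision logic):

    import java.math.{BigDecimal => JBigDecimal, RoundingMode}

    // Approximate stand-in for Decimal.toPrecision: round to `scale` with
    // HALF_UP, then check the rounded value still fits in `precision` digits.
    def toPrecisionSketch(
        d: JBigDecimal,
        precision: Int,
        scale: Int,
        nullOnOverflow: Boolean): JBigDecimal = {
      val rounded = d.setScale(scale, RoundingMode.HALF_UP)
      if (rounded.precision <= precision) {
        rounded
      } else if (nullOnOverflow) {
        null // legacy Spark/Hive behaviour, still the default
      } else {
        throw new ArithmeticException(
          s"$d cannot be represented as Decimal($precision, $scale).")
      }
    }

    assert(toPrecisionSketch(new JBigDecimal("10.1"), 4, 0, true).toString == "10")
    assert(toPrecisionSketch(new JBigDecimal("10.1"), 4, 3, true) == null)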
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index a8f758d625a0..d14eceb480f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -45,18 +45,26 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("CheckOverflow") {
     val d1 = Decimal("10.1")
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
-    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10"))
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1)
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1)
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null)
+    intercept[ArithmeticException](CheckOverflow(Literal(d1), DecimalType(4, 3), false).eval())
+    intercept[ArithmeticException](checkEvaluationWithMutableProjection(
+      CheckOverflow(Literal(d1), DecimalType(4, 3), false), null))
 
     val d2 = Decimal(101, 3, 1)
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
-    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10"))
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null)
+    intercept[ArithmeticException](CheckOverflow(Literal(d2), DecimalType(4, 3), false).eval())
+    intercept[ArithmeticException](checkEvaluationWithMutableProjection(
+      CheckOverflow(Literal(d2), DecimalType(4, 3), false), null))
 
-    checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
+    checkEvaluation(CheckOverflow(
+      Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null)
+    checkEvaluation(CheckOverflow(
+      Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null)
   }
-
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
similarity index 79%
rename from sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
rename to sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
index 28a0e20c0f49..35f2be46cd13 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/decimalArithmeticOperations.sql
@@ -83,4 +83,28 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
 select 123456789123456789.1234567890 * 1.123456789123456789;
 select 12345678912345.123456789123 / 0.000000012345678;
 
+-- throw an exception instead of returning NULL, according to SQL ANSI 2011
+set spark.sql.decimalOperations.nullOnOverflow=false;
+
+-- test operations between decimals and constants
+select id, a*10, b/10 from decimals_test order by id;
+
+-- test operations on constants
+select 10.3 * 3.0;
+select 10.3000 * 3.0;
+select 10.30000 * 30.0;
+select 10.300000000000000000 * 3.000000000000000000;
+select 10.300000000000000000 * 3.0000000000000000000;
+
+-- arithmetic operations causing an overflow throw exception
+select (5e36 + 0.1) + 5e36;
+select (-4e36 - 0.1) - 7e36;
+select 12345678901234567890.0 * 12345678901234567890.0;
+select 1e35 / 0.1;
+
+-- arithmetic operations causing a precision loss throw exception
+select 123456789123456789.1234567890 * 1.123456789123456789;
+select 123456789123456789.1234567890 * 1.123456789123456789;
+select 12345678912345.123456789123 / 0.000000012345678;
+
 drop table decimals_test;
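The regenerated harness results for these new cases follow. For reference, the failure mode can also be reproduced interactively (a hypothetical spark-shell session; the message is taken from the generated output file below):

    // spark-shell sketch of the new ANSI-style failure mode:
    spark.sql("set spark.sql.decimalOperations.nullOnOverflow=false")
    spark.sql("select (5e36 + 0.1) + 5e36").collect()
    // java.lang.ArithmeticException:
    //   Decimal(expanded,10000000000000000000000000000000000000.1,39,1}) cannot be
    //   represented as Decimal(38, 1).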
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out
similarity index 73%
rename from sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
rename to sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out
index cbf44548b3cc..217233bfad37 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/decimalArithmeticOperations.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 40
+-- Number of queries: 54
 
 
 -- !query 0
@@ -328,8 +328,131 @@ NULL
 
 
 -- !query 39
-drop table decimals_test
+set spark.sql.decimalOperations.nullOnOverflow=false
 -- !query 39 schema
-struct<>
+struct<key:string,value:string>
 -- !query 39 output
+spark.sql.decimalOperations.nullOnOverflow false
+
+
+-- !query 40
+select id, a*10, b/10 from decimals_test order by id
+-- !query 40 schema
+struct
+-- !query 40 output
+1 1000 99.9
+2 123451.23 1234.5123
+3 1.234567891011 123.41
+4 1234567891234567890 0.1123456789123456789
+
+
+-- !query 41
+select 10.3 * 3.0
+-- !query 41 schema
+struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)>
+-- !query 41 output
+30.9
+
+
+-- !query 42
+select 10.3000 * 3.0
+-- !query 42 schema
+struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)>
+-- !query 42 output
+30.9
+
+
+-- !query 43
+select 10.30000 * 30.0
+-- !query 43 schema
+struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)>
+-- !query 43 output
+309
+
+
+-- !query 44
+select 10.300000000000000000 * 3.000000000000000000
+-- !query 44 schema
+struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)>
+-- !query 44 output
+30.9
+
+
+-- !query 45
+select 10.300000000000000000 * 3.0000000000000000000
+-- !query 45 schema
+struct<>
+-- !query 45 output
+java.lang.ArithmeticException
+Decimal(expanded,30.900000000000000000000000000000000000,38,36}) cannot be represented as Decimal(38, 37).
+
+
+-- !query 46
+select (5e36 + 0.1) + 5e36
+-- !query 46 schema
+struct<>
+-- !query 46 output
+java.lang.ArithmeticException
+Decimal(expanded,10000000000000000000000000000000000000.1,39,1}) cannot be represented as Decimal(38, 1).
+
+
+-- !query 47
+select (-4e36 - 0.1) - 7e36
+-- !query 47 schema
+struct<>
+-- !query 47 output
+java.lang.ArithmeticException
+Decimal(expanded,-11000000000000000000000000000000000000.1,39,1}) cannot be represented as Decimal(38, 1).
+
+
+-- !query 48
+select 12345678901234567890.0 * 12345678901234567890.0
+-- !query 48 schema
+struct<>
+-- !query 48 output
+java.lang.ArithmeticException
+Decimal(expanded,1.5241578753238836750190519987501905210E+38,38,-1}) cannot be represented as Decimal(38, 2).
+
+
+-- !query 49
+select 1e35 / 0.1
+-- !query 49 schema
+struct<>
+-- !query 49 output
+java.lang.ArithmeticException
+Decimal(expanded,1000000000000000000000000000000000000,37,0}) cannot be represented as Decimal(38, 3).
+
+
+-- !query 50
+select 123456789123456789.1234567890 * 1.123456789123456789
+-- !query 50 schema
+struct<>
+-- !query 50 output
+java.lang.ArithmeticException
+Decimal(expanded,138698367904130467.65432098851562262075,38,20}) cannot be represented as Decimal(38, 28).
+
+
+-- !query 51
+select 123456789123456789.1234567890 * 1.123456789123456789
+-- !query 51 schema
+struct<>
+-- !query 51 output
+java.lang.ArithmeticException
+Decimal(expanded,138698367904130467.65432098851562262075,38,20}) cannot be represented as Decimal(38, 28).
+
+
+-- !query 52
+select 12345678912345.123456789123 / 0.000000012345678
+-- !query 52 schema
+struct<>
+-- !query 52 output
+java.lang.ArithmeticException
+Decimal(expanded,1000000073899961059796.7258663315210392,38,16}) cannot be represented as Decimal(38, 18).
+
+
+-- !query 53
+drop table decimals_test
+-- !query 53 schema
+struct<>
+-- !query 53 output
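For contrast, the default mode keeps the legacy answers that the untouched earlier queries in this file still pin down (note the NULL context line at the @@ -328 hunk above). A hypothetical spark-shell session:

    // With the default nullOnOverflow=true, the same overflowing sum yields NULL:
    spark.sql("set spark.sql.decimalOperations.nullOnOverflow=true")
    spark.sql("select (5e36 + 0.1) + 5e36").collect()
    // Array([null])  -- no exception, matching the NULL results earlier in the file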