diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index b315de3091dc9..d70fe92e7fe54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -62,18 +62,18 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)() + private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)() private lazy val zero = Literal.default(sumDataType) override lazy val aggBufferAttributes = resultType match { - case _: DecimalType => sum :: isEmptyOrNulls :: Nil + case _: DecimalType => sum :: isEmpty :: Nil case _ => sum :: Nil } override lazy val initialValues: Seq[Expression] = resultType match { - case _: DecimalType => Seq(zero, Literal.create(true, BooleanType)) - case other => Seq(Literal.create(null, other)) + case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType)) + case _ => Seq(Literal(null, resultType)) } /** @@ -97,29 +97,18 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val updateExpressions: Seq[Expression] = { if (child.nullable) { + val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) resultType match { - case d: DecimalType => - Seq( - /* sum */ - If(IsNull(sum), sum, - If(IsNotNull(child.cast(sumDataType)), - CheckOverflow(sum + child.cast(sumDataType), d, true), sum)), - /* isEmptyOrNulls */ - If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) - ) - case _ => - Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)) + case _: DecimalType => + Seq(updateSumExpr, isEmpty && child.isNull) + case _ => Seq(updateSumExpr) } } else { + val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType) resultType match { - case d: DecimalType => - Seq( - /* sum */ - If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)), - /* isEmptyOrNulls */ - false - ) - case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType)) + case _: DecimalType => + Seq(updateSumExpr, Literal(false, BooleanType)) + case _ => Seq(updateSumExpr) } } } @@ -138,19 +127,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast * If the value from bufferLeft and bufferRight are both true, then this will be true. */ override lazy val mergeExpressions: Seq[Expression] = { + val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left) resultType match { - case d: DecimalType => + case _: DecimalType => + val inputOverflow = !isEmpty.right && sum.right.isNull + val bufferOverflow = !isEmpty.left && sum.left.isNull Seq( - /* sum = */ - If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) || - And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)), - Literal.create(null, resultType), - CheckOverflow(sum.left + sum.right, d, true)), - /* isEmptyOrNulls = */ - And(isEmptyOrNulls.left, isEmptyOrNulls.right) - ) - case _ => - Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) + If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr), + isEmpty.left && isEmpty.right) + case _ => Seq(mergeSumExpr) } } @@ -163,11 +148,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast */ override lazy val evaluateExpression: Expression = resultType match { case d: DecimalType => - If(EqualTo(isEmptyOrNulls, true), - Literal.create(null, sumDataType), - If(And(SQLConf.get.ansiEnabled, IsNull(sum)), - OverflowException(resultType, "Arithmetic Operation overflow"), sum)) + If(isEmpty, Literal.create(null, sumDataType), + CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)) case _ => sum } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index ddd7940fc0e19..9edd5cac75c5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -146,22 +146,53 @@ case class CheckOverflow( override def sql: String = child.sql } -case class OverflowException(dtype: DataType, msg: String) extends LeafExpression { - - override def dataType: DataType = dtype +// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`. +case class CheckOverflowInSum( + child: Expression, + dataType: DecimalType, + nullOnOverflow: Boolean) extends UnaryExpression { - override def nullable: Boolean = false + override def nullable: Boolean = true - def eval(input: InternalRow): Any = { - Decimal.throwArithmeticException(msg) + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.") + } else { + input.asInstanceOf[Decimal].toPrecision( + dataType.precision, + dataType.scale, + Decimal.ROUND_HALF_UP, + nullOnOverflow) + } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = code""" - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |${ev.value} = Decimal.throwArithmeticException("${msg}"); - |""", isNull = FalseLiteral) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val nullHandling = if (nullOnOverflow) { + "" + } else { + s""" + |throw new ArithmeticException("Overflow in sum of decimals."); + |""".stripMargin + } + val code = code""" + |${childGen.code} + |boolean ${ev.isNull} = ${childGen.isNull}; + |Decimal ${ev.value} = null; + |if (${childGen.isNull}) { + | $nullHandling + |} else { + | ${ev.value} = ${childGen.value}.toPrecision( + | ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow); + | ${ev.isNull} = ${ev.value} == null; + |} + |""".stripMargin + + ev.copy(code = code) } - override def toString: String = "OverflowException" + override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 75e3cf4ad7a67..f32e48e1cc128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -651,9 +651,4 @@ object Decimal { override def quot(x: Decimal, y: Decimal): Decimal = x quot y override def rem(x: Decimal, y: Decimal): Decimal = x % y } - - - def throwArithmeticException(msg: String): Decimal = { - throw new ArithmeticException(msg) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f4119b412e7e3..bbcb9df455501 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest structDf.select(xxhash64($"a", $"record.*"))) } + private def assertDecimalSumOverflow( + df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = { + if (!ansiEnabled) { + try { + checkAnswer(df, expectedAnswer) + } catch { + case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] => + // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail + // to read it. + assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } else { + val e = intercept[SparkException] { + df.collect + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals") || + e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38")) + } + } + test("SPARK-28224: Aggregate sum big decimal overflow") { val largeDecimals = spark.sparkContext.parallelize( DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: @@ -200,24 +222,12 @@ class DataFrameSuite extends QueryTest Seq(true, false).foreach { ansiEnabled => withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { val structDf = largeDecimals.select("a").agg(sum("a")) - checkAnsi(structDf, ansiEnabled, Row(null)) - } - } - } - - private def checkAnsi(df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row ): Unit = { - if (!ansiEnabled) { - checkAnswer(df, expectedAnswer) - } else { - val e = intercept[SparkException] { - df.collect() + assertDecimalSumOverflow(structDf, ansiEnabled, Row(null)) } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("Arithmetic Operation overflow")) } } - test("test sum on null decimal values") { + test("SPARK-28067: sum of null decimal values") { Seq("true", "false").foreach { wholeStageEnabled => withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { Seq("true", "false").foreach { ansiEnabled => @@ -254,26 +264,27 @@ class DataFrameSuite extends QueryTest join(df, "intNum").agg(sum("decNum")) val expectedAnswer = Row(null) - checkAnsi(df2, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer) val decStr = "1" + "0" * 19 val d1 = spark.range(0, 12, 1, 1) val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) - checkAnsi(d2, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer) val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d")) - checkAnsi(d4, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer) val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"), lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd") - checkAnsi(d5, ansiEnabled, expectedAnswer) + assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer) val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d")) val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")). toDF("d") - checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) + assertDecimalSumOverflow( + nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer) val df3 = Seq( (BigDecimal("10000000000000000000"), 1), @@ -293,7 +304,7 @@ class DataFrameSuite extends QueryTest val df6 = df3.union(df4).union(df5) val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")). filter("intNum == 1") - checkAnsi(df7, ansiEnabled, Row(1, null, 2)) + assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2)) } } }