diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index ef204ec82c52..d04fe9249d06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 @ExpressionDescription(
@@ -89,5 +90,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     )
   }
 
-  override lazy val evaluateExpression: Expression = sum
+  // Decimal sums are wrapped in CheckOverflow against the declared result type. Per the
+  // SPARK-28224 test, an overflowing total then becomes null when the nullOnOverflow conf is
+  // enabled and raises an ArithmeticException otherwise, instead of being returned unchecked.
+  override lazy val evaluateExpression: Expression = resultType match {
+    case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow)
+    case _ => sum
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ba8fced983c6..c6daff1479fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -156,6 +156,31 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
       structDf.select(xxhash64($"a", $"record.*")))
   }
 
+  test("SPARK-28224: Aggregate sum big decimal overflow") {
+    // Both values carry 20 integer digits and their sum needs 21, which is expected to exceed
+    // what the decimal result type of sum() can represent (precision assumed, not shown here).
+    val largeDecimals = spark.sparkContext.parallelize(
+      DecimalData(BigDecimal("1" * 20 + ".123"), BigDecimal("1" * 20 + ".123")) ::
+      DecimalData(BigDecimal("9" * 20 + ".123"), BigDecimal("9" * 20 + ".123")) :: Nil).toDF()
+
+    Seq(true, false).foreach { nullOnOverflow =>
+      withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> nullOnOverflow.toString) {
+        val summed = largeDecimals.select("a").agg(sum("a"))
+        if (nullOnOverflow) {
+          // Overflow is reported as a null result.
+          checkAnswer(summed, Row(null))
+        } else {
+          // Overflow fails the query: a SparkException caused by an ArithmeticException.
+          val e = intercept[SparkException] {
+            summed.collect()
+          }
+          assert(e.getCause.getClass === classOf[ArithmeticException])
+          assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
+        }
+      }
+    }
+  }
+
   test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
     val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
     val e = intercept[AnalysisException] {