diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 01947273b6cc..47b4e643395e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -214,10 +214,10 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
   override def toString = s"AVG($child)"
 
   override def asPartial: SplitEvaluation = {
-    val partialSum = Alias(Sum(child), "PartialSum")()
-    val partialCount = Alias(Count(child), "PartialCount")()
-    val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
-    val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+    val partialSum = Alias(Sum(Cast(child, dataType)), "PartialSum")()
+    val partialCount = Alias(Cast(Count(child), dataType), "PartialCount")()
+    val castedSum = Sum(partialSum.toAttribute)
+    val castedCount = Sum(partialCount.toAttribute)
 
     SplitEvaluation(
       Divide(castedSum, castedCount),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 95860e6683f6..c47c243bc0f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -123,6 +123,12 @@ class SQLQuerySuite extends QueryTest {
       2.0)
   }
 
+  test("average overflow test") {
+    checkAnswer(
+      sql("SELECT AVG(a),b FROM testData1 group by b"),
+      Seq((2147483645.0,1),(2.0,2)))
+  }
+
   test("count") {
     checkAnswer(
       sql("SELECT COUNT(*) FROM testData2"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 944f520e4351..b8fe3c385632 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -30,6 +30,17 @@ object TestData {
     (1 to 100).map(i => TestData(i, i.toString)))
   testData.registerAsTable("testData")
 
+  case class TestData1(a: Int, b: Int)
+  val testData1: SchemaRDD =
+    TestSQLContext.sparkContext.parallelize(
+      TestData1(2147483644, 1) ::
+      TestData1(1, 2) ::
+      TestData1(2147483645, 1) ::
+      TestData1(2, 2) ::
+      TestData1(2147483646, 1) ::
+      TestData1(3, 2) :: Nil)
+  testData1.registerAsTable("testData1")
+
   case class TestData2(a: Int, b: Int)
   val testData2: SchemaRDD =
     TestSQLContext.sparkContext.parallelize(