diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 4fc0256bce23..8ae24e51351d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -87,8 +87,12 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _: DecimalType => DecimalPrecision.decimalAndDecimal()( Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType) - case _: YearMonthIntervalType => DivideYMInterval(sum, count) - case _: DayTimeIntervalType => DivideDTInterval(sum, count) + case _: YearMonthIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, YearMonthIntervalType), DivideYMInterval(sum, count)) + case _: DayTimeIntervalType => + If(EqualTo(count, Literal(0L)), + Literal(null, DayTimeIntervalType), DivideDTInterval(sum, count)) case _ => Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c53bcf045d00..c6f6cbdbf02c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1135,7 +1135,7 @@ class DataFrameAggregateSuite extends QueryTest val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time")) checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: - Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) ::Nil) + Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil) assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false), StructField("sum(year-month)", YearMonthIntervalType), @@ -1173,7 +1173,7 @@ class DataFrameAggregateSuite extends QueryTest val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time")) checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) :: Row(2, Period.ofMonths(1), Duration.ofDays(1)) :: - Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil) + Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil) assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined) assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false), StructField("avg(year-month)", YearMonthIntervalType), @@ -1188,6 +1188,13 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(df2.select(avg($"day-time")), Nil) } assert(error2.toString contains "java.lang.ArithmeticException: long overflow") + + val df3 = df.filter($"class" > 4) + val avgDF3 = df3.select(avg($"year-month"), avg($"day-time")) + checkAnswer(avgDF3, Row(null, null) :: Nil) + + val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time")) + checkAnswer(avgDF4, Nil) } }