diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 66ac73087b4d..aaad3c7bcefa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types._ 2.0 > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col); 1.5 + > SELECT _FUNC_(cast(v as interval)) FROM VALUES ('-1 weeks'), ('2 seconds'), (null) t(v); + -3 days -11 hours -59 minutes -59 seconds """, since = "1.0.0") case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { @@ -39,10 +41,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit override def children: Seq[Expression] = child :: Nil - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def nullable: Boolean = true @@ -52,11 +51,13 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) + case interval: CalendarIntervalType => interval case _ => DoubleType } private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) + case interval: CalendarIntervalType => interval case _ => DoubleType } @@ -66,7 +67,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit override lazy val aggBufferAttributes = sum :: count :: Nil override lazy val initialValues = Seq( - /* sum = */ 
Literal(0).cast(sumDataType), + /* sum = */ Literal.default(sumDataType), /* count = */ Literal(0L) ) @@ -79,6 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit override lazy val evaluateExpression = child.dataType match { case _: DecimalType => DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType) + case CalendarIntervalType => + DivideInterval(sum.cast(resultType), count.cast(DoubleType)) case _ => sum.cast(resultType) / count.cast(resultType) } @@ -87,7 +90,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit /* sum = */ Add( sum, - coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))), + coalesce(child.cast(sumDataType), Literal.default(sumDataType))), /* count = */ If(child.isNull, count, count + 1L) ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 9beb07a2f030..93ea3221e747 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -150,7 +150,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") assertError(Sum('booleanField), "requires (numeric or interval) type") - assertError(Average('booleanField), "function average requires numeric type") + assertError(Average('booleanField), "requires (numeric or interval) type") } test("check types for others") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 72531718c8b0..c405fb0aa9e8 100644 --- 
a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -192,4 +192,36 @@ having sv is not null; SELECT i, Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) -FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v); \ No newline at end of file +FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v); + +-- average with interval type +-- null +select avg(cast(v as interval)) from VALUES (null) t(v); + +-- empty set +select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0; + +-- basic interval avg +select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v); +select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v); +select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v); +select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v); + +-- group by +select + i, + avg(cast(v as interval)) +from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) +group by i; + +-- having +select + avg(cast(v as interval)) as sv +from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v) +having sv is not null; + +-- window +SELECT + i, + avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) +FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 0417bfb070d4..150ee8aab01e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 65 +-- Number of queries: 74 -- !query 0 @@ -660,3 +660,90 
@@ struct<i:int,sum(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
 1	2 seconds
 2	NULL
 2	NULL
+
+
+-- !query 65
+select avg(cast(v as interval)) from VALUES (null) t(v)
+-- !query 65 schema
+struct<avg(CAST(v AS INTERVAL)):interval>
+-- !query 65 output
+NULL
+
+
+-- !query 66
+select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
+-- !query 66 schema
+struct<avg(CAST(v AS INTERVAL)):interval>
+-- !query 66 output
+NULL
+
+
+-- !query 67
+select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
+-- !query 67 schema
+struct<avg(CAST(v AS INTERVAL)):interval>
+-- !query 67 output
+1.5 seconds
+
+
+-- !query 68
+select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
+-- !query 68 schema
+struct<avg(CAST(v AS INTERVAL)):interval>
+-- !query 68 output
+0.5 seconds
+
+
+-- !query 69
+select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
+-- !query 69 schema
+struct<avg(CAST(v AS INTERVAL)):interval>
+-- !query 69 output
+-1.5 seconds
+
+
+-- !query 70
+select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
+-- !query 70 schema
+struct<avg(CAST(v AS INTERVAL)):interval>
+-- !query 70 output
+-3 days -11 hours -59 minutes -59 seconds
+
+
+-- !query 71
+select
+  i,
+  avg(cast(v as interval))
+from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
+group by i
+-- !query 71 schema
+struct<i:int,avg(CAST(v AS INTERVAL)):interval>
+-- !query 71 output
+1	-1 days
+2	2 seconds
+3	NULL
+
+
+-- !query 72
+select
+  avg(cast(v as interval)) as sv
+from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
+having sv is not null
+-- !query 72 schema
+struct<sv:interval>
+-- !query 72 output
+-15 hours -59 minutes -59.333333 seconds
+
+
+-- !query 73
+SELECT
+  i,
+  avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
+FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v)
+-- !query 73 schema
+struct<i:int,avg(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
+-- !query 73 output
+1	1.5 seconds
+1	2 seconds
+2	NULL
+2	NULL