diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index c2ab8adfaef67..843c361233956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -17,13 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
+// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.",
   examples = """
@@ -34,8 +33,11 @@ import org.apache.spark.sql.types._
        25
       > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
        NULL
+      > SELECT _FUNC_(cast(col as interval)) FROM VALUES ('1 seconds'), ('2 seconds'), (null) tab(col);
+       interval 3 seconds
   """,
   since = "1.0.0")
+// scalastyle:on line.size.limit
 case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = child :: Nil
@@ -45,14 +47,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
 
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
+    case _: CalendarIntervalType => CalendarIntervalType
     case _: IntegralType => LongType
     case _ => DoubleType
   }
@@ -61,7 +61,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
-  private lazy val zero = Cast(Literal(0), sumDataType)
+  private lazy val zero = Literal.default(resultType)
 
   override lazy val aggBufferAttributes = sum :: Nil
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ed11bce5d12b4..9beb07a2f0303 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -149,7 +149,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
     assertError(Min('mapField), "min does not support ordering on type")
     assertError(Max('mapField), "max does not support ordering on type")
-    assertError(Sum('booleanField), "function sum requires numeric type")
+    assertError(Sum('booleanField), "requires (numeric or interval) type")
    assertError(Average('booleanField), "function average requires numeric type")
   }
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index fcde225676cb9..72531718c8b05 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -162,3 +162,34 @@ SELECT count(*) FROM test_agg WHERE count(*) > 1L;
 SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L;
 SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1;
 
+-- sum interval values
+-- null
+select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where v is null;
+
+-- empty set
+select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0;
+
+-- basic interval sum
+select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
+select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
+select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
+select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);
+
+-- group by
+select
+  i,
+  sum(cast(v as interval))
+from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
+group by i;
+
+-- having
+select
+  sum(cast(v as interval)) as sv
+from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
+having sv is not null;
+
+-- window
+SELECT
+  i,
+  Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
+FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 545aa238dd756..eed6e02798895 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 56
+-- Number of queries: 65
 
 
 -- !query 0
@@ -573,3 +573,90 @@ org.apache.spark.sql.AnalysisException
 Aggregate/Window/Generate expressions are not valid in where clause of the query.
 Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))]
 Invalid expressions: [count(1), max(test_agg.`k`)];
+
+
+-- !query 56
+select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where v is null
+-- !query 56 schema
+struct
+-- !query 56 output
+NULL
+
+
+-- !query 57
+select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
+-- !query 57 schema
+struct
+-- !query 57 output
+NULL
+
+
+-- !query 58
+select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
+-- !query 58 schema
+struct
+-- !query 58 output
+interval 3 seconds
+
+
+-- !query 59
+select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
+-- !query 59 schema
+struct
+-- !query 59 output
+interval 1 seconds
+
+
+-- !query 60
+select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
+-- !query 60 schema
+struct
+-- !query 60 output
+interval -3 seconds
+
+
+-- !query 61
+select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
+-- !query 61 schema
+struct
+-- !query 61 output
+interval -7 days 2 seconds
+
+
+-- !query 62
+select
+  i,
+  sum(cast(v as interval))
+from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
+group by i
+-- !query 62 schema
+struct
+-- !query 62 output
+1 interval -2 days
+2 interval 2 seconds
+3 NULL
+
+
+-- !query 63
+select
+  sum(cast(v as interval)) as sv
+from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
+having sv is not null
+-- !query 63 schema
+struct
+-- !query 63 output
+interval -2 days 2 seconds
+
+
+-- !query 64
+SELECT
+  i,
+  Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
+FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v)
+-- !query 64 schema
+struct
+-- !query 64 output
+1 interval 2 seconds
+1 interval 3 seconds
+2 NULL
+2 NULL
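
For reviewers trying the patch locally, a minimal spark-shell sketch (illustrative only, not part of the diff; it assumes a SparkSession named `spark` built with this change applied):

    // Sum over calendar intervals, mirroring query 58 in the golden file above.
    val df = spark.sql(
      "SELECT sum(cast(v AS interval)) FROM VALUES ('1 seconds'), ('2 seconds'), (null) t(v)")
    df.show(truncate = false)
    // Expected result per the golden file: interval 3 seconds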