Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.",
examples = """
Expand All @@ -34,8 +33,11 @@ import org.apache.spark.sql.types._
25
> SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
NULL
> SELECT _FUNC_(cast(col as interval)) FROM VALUES ('1 seconds'), ('2 seconds'), (null) tab(col);
interval 3 seconds
""",
since = "1.0.0")
// scalastyle:on line.size.limit
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {

override def children: Seq[Expression] = child :: Nil
Expand All @@ -45,14 +47,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
// Return data type.
override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
case _: CalendarIntervalType => CalendarIntervalType
case _: IntegralType => LongType
case _ => DoubleType
}
Expand All @@ -61,7 +61,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

private lazy val sum = AttributeReference("sum", sumDataType)()

private lazy val zero = Cast(Literal(0), sumDataType)
private lazy val zero = Literal.default(resultType)

override lazy val aggBufferAttributes = sum :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

assertError(Min('mapField), "min does not support ordering on type")
assertError(Max('mapField), "max does not support ordering on type")
assertError(Sum('booleanField), "function sum requires numeric type")
assertError(Sum('booleanField), "requires (numeric or interval) type")
assertError(Average('booleanField), "function average requires numeric type")
}

Expand Down
31 changes: 31 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,34 @@ SELECT count(*) FROM test_agg WHERE count(*) > 1L;
SELECT count(*) FROM test_agg WHERE count(*) + 1L > 1L;
SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(k) > 1;

-- sum interval values
-- null
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where v is null;

-- empty set
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0;

-- basic interval sum
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);

-- group by
select
i,
sum(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i;

-- having
select
sum(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null;

-- window
SELECT
i,
Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);
89 changes: 88 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 56
-- Number of queries: 65


-- !query 0
Expand Down Expand Up @@ -573,3 +573,90 @@ org.apache.spark.sql.AnalysisException
Aggregate/Window/Generate expressions are not valid in where clause of the query.
Expression in where clause: [(((test_agg.`k` = 1) OR (test_agg.`k` = 2)) OR (((count(1) + 1L) > 1L) OR (max(test_agg.`k`) > 1)))]
Invalid expressions: [count(1), max(test_agg.`k`)];


-- !query 56
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where v is null
-- !query 56 schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query 56 output
NULL


-- !query 57
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
-- !query 57 schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query 57 output
NULL


-- !query 58
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
-- !query 58 schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query 58 output
interval 3 seconds


-- !query 59
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
-- !query 59 schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query 59 output
interval 1 seconds


-- !query 60
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
-- !query 60 schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query 60 output
interval -3 seconds


-- !query 61
select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
-- !query 61 schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query 61 output
interval -7 days 2 seconds


-- !query 62
select
i,
sum(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i
-- !query 62 schema
struct<i:int,sum(CAST(v AS INTERVAL)):interval>
-- !query 62 output
1 interval -2 days
2 interval 2 seconds
3 NULL


-- !query 63
select
sum(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null
-- !query 63 schema
struct<sv:interval>
-- !query 63 output
interval -2 days 2 seconds


-- !query 64
SELECT
i,
Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v)
-- !query 64 schema
struct<i:int,sum(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
-- !query 64 output
1 interval 2 seconds
1 interval 3 seconds
2 NULL
2 NULL