Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.sql.types._
2.0
> SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
1.5
> SELECT _FUNC_(cast(v as interval)) FROM VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);
-3 days -11 hours -59 minutes -59 seconds
""",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
Expand All @@ -39,10 +41,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

override def children: Seq[Expression] = child :: Nil

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def nullable: Boolean = true

Expand All @@ -52,11 +51,13 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case interval: CalendarIntervalType => interval
case _ => DoubleType
}

private lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case interval: CalendarIntervalType => interval
case _ => DoubleType
}

Expand All @@ -66,7 +67,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
override lazy val aggBufferAttributes = sum :: count :: Nil

override lazy val initialValues = Seq(
/* sum = */ Literal(0).cast(sumDataType),
/* sum = */ Literal.default(sumDataType),
/* count = */ Literal(0L)
)

Expand All @@ -79,6 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
override lazy val evaluateExpression = child.dataType match {
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
case CalendarIntervalType =>
DivideInterval(sum.cast(resultType), count.cast(DoubleType))
case _ =>
sum.cast(resultType) / count.cast(resultType)
}
Expand All @@ -87,7 +90,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
/* sum = */
Add(
sum,
coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))),
coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
/* count = */ If(child.isNull, count, count + 1L)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Min('mapField), "min does not support ordering on type")
assertError(Max('mapField), "max does not support ordering on type")
assertError(Sum('booleanField), "requires (numeric or interval) type")
assertError(Average('booleanField), "function average requires numeric type")
assertError(Average('booleanField), "requires (numeric or interval) type")
}

test("check types for others") {
Expand Down
34 changes: 33 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/group-by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,36 @@ having sv is not null;
SELECT
i,
Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);

-- average with interval type
-- null
select avg(cast(v as interval)) from VALUES (null) t(v);

-- empty set
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0;

-- basic interval avg
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);

-- group by
select
i,
avg(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i;

-- having
select
avg(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null;

-- window
SELECT
i,
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v);
89 changes: 88 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/group-by.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 65
-- Number of queries: 74


-- !query 0
Expand Down Expand Up @@ -660,3 +660,90 @@ struct<i:int,sum(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETW
1 3 seconds
2 NULL
2 NULL


-- !query 65
select avg(cast(v as interval)) from VALUES (null) t(v)
-- !query 65 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 65 output
NULL


-- !query 66
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
-- !query 66 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 66 output
NULL


-- !query 67
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
-- !query 67 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 67 output
1.5 seconds


-- !query 68
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
-- !query 68 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 68 output
0.5 seconds


-- !query 69
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
-- !query 69 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 69 output
-1.5 seconds


-- !query 70
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
-- !query 70 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 70 output
-3 days -11 hours -59 minutes -59 seconds


-- !query 71
select
i,
avg(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i
-- !query 71 schema
struct<i:int,avg(CAST(v AS INTERVAL)):interval>
-- !query 71 output
1 -1 days
2 2 seconds
3 NULL


-- !query 72
select
avg(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null
-- !query 72 schema
struct<sv:interval>
-- !query 72 output
-15 hours -59 minutes -59.333333 seconds


-- !query 73
SELECT
i,
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v)
-- !query 73 schema
struct<i:int,avg(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
-- !query 73 output
1 1.5 seconds
1 2 seconds
2 NULL
2 NULL