sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry}
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

@ExpressionDescription(
@@ -30,8 +31,6 @@ import org.apache.spark.sql.types._
2.0
> SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
1.5
> SELECT _FUNC_(cast(v as interval)) FROM VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);
-3 days -11 hours -59 minutes -59 seconds
""",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
@@ -40,7 +39,10 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

override def children: Seq[Expression] = child :: Nil

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")

override def nullable: Boolean = true

@@ -50,13 +52,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case interval: CalendarIntervalType => interval
case _ => DoubleType
}

private lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case interval: CalendarIntervalType => interval
case _ => DoubleType
}

@@ -79,9 +79,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
override lazy val evaluateExpression = child.dataType match {
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
case CalendarIntervalType =>
val newCount = If(EqualTo(count, Literal(0L)), Literal(null, LongType), count)
DivideInterval(sum.cast(resultType), newCount.cast(DoubleType))
case _ =>
sum.cast(resultType) / count.cast(resultType)
}
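Note on the type rules the hunks above keep for Average: a decimal input widens precision and scale by 4 to leave room for the divide, the internal sum gets 10 extra digits, and anything else averages as Double. A minimal standalone sketch of that rule follows (no Spark dependency; the type case classes and the `bounded` helper are illustrative stand-ins mirroring DecimalType.bounded and Spark's 38-digit cap, not Spark's API):

// Standalone sketch of Average's result-type rule, under the assumptions above.
object AverageResultTypeSketch {
  val MaxPrecision = 38 // Spark caps decimals at 38 digits

  sealed trait DataType
  case class DecimalType(precision: Int, scale: Int) extends DataType
  case object DoubleType extends DataType

  // Clamp precision and scale into the supported range, like DecimalType.bounded.
  def bounded(p: Int, s: Int): DecimalType =
    DecimalType(math.min(p, MaxPrecision), math.min(s, MaxPrecision))

  def avgResultType(child: DataType): DataType = child match {
    case DecimalType(p, s) => bounded(p + 4, s + 4) // headroom for the divide
    case _                 => DoubleType
  }

  def main(args: Array[String]): Unit = {
    println(avgResultType(DecimalType(10, 2))) // DecimalType(14,6)
    println(avgResultType(DoubleType))         // DoubleType
  }
}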
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -17,12 +17,13 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.",
examples = """
@@ -33,11 +34,8 @@ import org.apache.spark.sql.types._
25
> SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
NULL
> SELECT _FUNC_(cast(col as interval)) FROM VALUES ('1 seconds'), ('2 seconds'), (null) tab(col);
3 seconds
""",
since = "1.0.0")
// scalastyle:on line.size.limit
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {

override def children: Seq[Expression] = child :: Nil
@@ -47,12 +45,14 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
// Return data type.
override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sum")

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
case _: CalendarIntervalType => CalendarIntervalType
case _: IntegralType => LongType
case _ => DoubleType
}
@@ -61,7 +61,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

private lazy val sum = AttributeReference("sum", sumDataType)()

private lazy val zero = Literal.default(resultType)
private lazy val zero = Literal.default(sumDataType)

override lazy val aggBufferAttributes = sum :: Nil

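Two things happen in the Sum hunks above besides dropping the interval cases: inputTypes narrows to NumericType with an explicit checkInputDataTypes, and the buffer's initial value `zero` is now typed as sumDataType, matching the `sum` buffer attribute it seeds rather than the result type. A standalone sketch of the resulting type promotion (illustrative stand-in types, not Spark's classes):

// Standalone sketch of Sum's type promotion after this change: fixed decimals
// gain 10 digits of headroom, integral inputs accumulate as Long, and
// everything else sums as Double.
object SumResultTypeSketch {
  sealed trait DataType
  case class DecimalType(precision: Int, scale: Int) extends DataType
  case object IntegerType extends DataType
  case object LongType extends DataType
  case object DoubleType extends DataType

  def sumResultType(child: DataType): DataType = child match {
    case DecimalType(p, s)      => DecimalType(math.min(p + 10, 38), s)
    case IntegerType | LongType => LongType // all integral types widen to Long
    case _                      => DoubleType
  }

  def main(args: Array[String]): Unit = {
    println(sumResultType(IntegerType))        // LongType
    println(sumResultType(DecimalType(12, 2))) // DecimalType(22,2)
  }
}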
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -158,8 +158,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

assertError(Min(Symbol("mapField")), "min does not support ordering on type")
assertError(Max(Symbol("mapField")), "max does not support ordering on type")
assertError(Sum(Symbol("booleanField")), "requires (numeric or interval) type")
assertError(Average(Symbol("booleanField")), "requires (numeric or interval) type")
assertError(Sum(Symbol("booleanField")), "function sum requires numeric type")
assertError(Average(Symbol("booleanField")), "function average requires numeric type")
}

test("check types for others") {
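The suite now expects the reworded errors above. A hedged sketch of how the check surfaces to users follows; it assumes a Spark build with this change on the classpath, the session setup is ordinary boilerplate, and the exact message text comes from TypeUtils.checkForNumericExpr:

import org.apache.spark.sql.{AnalysisException, SparkSession}

object SumTypeCheckDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    try {
      // A boolean column fails sum's type check; after this change the error
      // reads "function sum requires numeric type" instead of mentioning intervals.
      spark.sql("SELECT sum(col) FROM VALUES (true), (false) AS tab(col)").collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    } finally {
      spark.stop()
    }
  }
}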
64 changes: 0 additions & 64 deletions sql/core/src/test/resources/sql-tests/inputs/interval.sql
@@ -84,70 +84,6 @@ select interval (-30) day;
select interval (a + 1) day;
select interval 30 day day day;

-- sum interval values
-- null
select sum(cast(null as interval));

-- empty set
select sum(cast(v as interval)) from VALUES ('1 seconds') t(v) where 1=0;

-- basic interval sum
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);

-- group by
select
i,
sum(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i;

-- having
select
sum(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null;

-- window
SELECT
i,
sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1, '1 seconds'), (1, '2 seconds'), (2, NULL), (2, NULL) t(i,v);

-- average with interval type
-- null
select avg(cast(v as interval)) from VALUES (null) t(v);

-- empty set
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0;

-- basic interval avg
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);

-- group by
select
i,
avg(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i;

-- having
select
avg(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null;

-- window
SELECT
i,
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v);

-- Interval year-month arithmetic

create temporary view interval_arithmetic as
176 changes: 1 addition & 175 deletions sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 99
-- Number of queries: 81


-- !query
@@ -631,180 +631,6 @@ select interval 30 day day day
-----------------------^^^


-- !query
select sum(cast(null as interval))
-- !query schema
struct<sum(CAST(NULL AS INTERVAL)):interval>
-- !query output
NULL


-- !query
select sum(cast(v as interval)) from VALUES ('1 seconds') t(v) where 1=0
-- !query schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query output
NULL


-- !query
select sum(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
-- !query schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query output
3 seconds


-- !query
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
-- !query schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query output
1 seconds


-- !query
select sum(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
-- !query schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query output
-3 seconds


-- !query
select sum(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
-- !query schema
struct<sum(CAST(v AS INTERVAL)):interval>
-- !query output
-7 days 2 seconds


-- !query
select
i,
sum(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i
-- !query schema
struct<i:int,sum(CAST(v AS INTERVAL)):interval>
-- !query output
1 -2 days
2 2 seconds
3 NULL


-- !query
select
sum(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null
-- !query schema
struct<sv:interval>
-- !query output
-2 days 2 seconds


-- !query
SELECT
i,
sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1, '1 seconds'), (1, '2 seconds'), (2, NULL), (2, NULL) t(i,v)
-- !query schema
struct<i:int,sum(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
-- !query output
1 2 seconds
1 3 seconds
2 NULL
2 NULL


-- !query
select avg(cast(v as interval)) from VALUES (null) t(v)
-- !query schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query output
NULL


-- !query
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
-- !query schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query output
NULL


-- !query
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
-- !query schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query output
1.5 seconds


-- !query
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
-- !query schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query output
0.5 seconds


-- !query
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
-- !query schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query output
-1.5 seconds


-- !query
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
-- !query schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query output
-3 days -11 hours -59 minutes -59 seconds


-- !query
select
i,
avg(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i
-- !query schema
struct<i:int,avg(CAST(v AS INTERVAL)):interval>
-- !query output
1 -1 days
2 2 seconds
3 NULL


-- !query
select
avg(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null
-- !query schema
struct<sv:interval>
-- !query output
-15 hours -59 minutes -59.333333 seconds


-- !query
SELECT
i,
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v)
-- !query schema
struct<i:int,avg(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
-- !query output
1 1.5 seconds
1 2 seconds
2 NULL
2 NULL


-- !query
create temporary view interval_arithmetic as
select CAST(dateval AS date), CAST(tsval AS timestamp) from values
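The deleted expected output for avg('-1 weeks', '2 seconds') is easy to sanity-check by hand: the sum is -604798 seconds, half of that is -302399 seconds, which is -3 days -11 hours -59 minutes -59 seconds. A plain-Scala check of that arithmetic; it deliberately flattens everything to seconds and ignores Spark's separate month/day/microsecond interval fields:

// Sanity check for the removed avg('-1 weeks', '2 seconds') expected output.
object IntervalAvgArithmetic {
  def main(args: Array[String]): Unit = {
    val secondsPerDay = 86400L
    val sumSeconds = -7 * secondsPerDay + 2 // '-1 weeks' + '2 seconds' = -604798 s
    val avgSeconds = sumSeconds / 2         // -302399 s
    val days    = avgSeconds / secondsPerDay        // -3
    val hours   = avgSeconds % secondsPerDay / 3600 // -11
    val minutes = avgSeconds % 3600 / 60            // -59
    val seconds = avgSeconds % 60                   // -59
    println(s"$days days $hours hours $minutes minutes $seconds seconds")
  }
}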