@@ -79,8 +79,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 private[sql] object TypeCollection {
 
   /**
-   * Types that include numeric types and interval type. They are only used in unary_minus,
-   * unary_positive, add and subtract operations.
+   * Types that include numeric types and interval type, which support numeric calculations,
+   * i.e. unary_minus, unary_positive, sum, avg, min, max, add and subtract operations.
    */
   val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
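To make the widened collection concrete, a small illustrative sketch (hypothetical snippet; `TypeCollection` is `private[sql]`, so it only compiles inside the `org.apache.spark.sql` package, e.g. in Spark's own test sources):

import org.apache.spark.sql.types._

// All numeric types are accepted, as before...
assert(TypeCollection.NumericAndInterval.acceptsType(IntegerType))
assert(TypeCollection.NumericAndInterval.acceptsType(DecimalType(10, 2)))
// ...and so is the calendar interval type.
assert(TypeCollection.NumericAndInterval.acceptsType(CalendarIntervalType))
// Anything else is still rejected.
assert(!TypeCollection.NumericAndInterval.acceptsType(StringType))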
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (6 changes: 3 additions, 3 deletions)
@@ -268,9 +268,9 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def numericColumns: Seq[Expression] = {
-    schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
+  private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
+    queryExecution.analyzed.output.filter { attr =>
+      TypeCollection.NumericAndInterval.acceptsType(attr.dataType)
     }
   }
 
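The behavioral difference, sketched under the assumption of a spark-shell session on a build that includes this change (column names here are illustrative): the old `numericColumns` silently dropped interval columns from no-argument aggregations, while the new filter keeps them.

val df = spark.range(2).selectExpr(
  "cast(id as int) as i",      // numeric: aggregated before and after
  "cast(id as string) as s",   // string: never aggregated
  "interval 1 day as iv")      // calendar interval: now aggregated too

// Before this change: produces only sum(i).
// After this change: produces sum(i) and sum(iv).
df.groupBy().sum().show()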
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{StructType, TypeCollection}
 
 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -88,20 +88,20 @@ class RelationalGroupedDataset protected[sql](
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }

-  private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
-    : DataFrame = {
+  private[this] def aggregateNumericOrIntervalColumns(
+      colNames: String*)(f: Expression => AggregateFunction): DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
-      // No columns specified. Use all numeric columns.
-      df.numericColumns
+      // No columns specified. Use all columns that support numeric calculations.
+      df.numericCalculationSupportedColumns
     } else {
-      // Make sure all specified columns are numeric.
+      // Make sure all specified columns support numeric calculations.
       colNames.map { colName =>
         val namedExpr = df.resolve(colName)
-        if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+        if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {
           throw new AnalysisException(
-            s""""$colName" is not a numeric column. """ +
-              "Aggregation function can only be applied on a numeric column.")
+            s""""$colName" is not a numeric or calendar interval column. """ +
+              "Aggregation function can only be applied on a numeric or calendar interval column.")
         }
         namedExpr
       }

[Inline review comment from a Member]
Can you update the comment to make it a more general one?

> * Types that include numeric types and interval type. They are only used in unary_minus,

[Reply from the Author]
Thanks for your suggestion. Please check.
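A hedged sketch of the updated failure mode (spark-shell session assumed, with `spark.implicits._` in scope): a column that is neither numeric nor a calendar interval still fails eagerly, now with the broader message.

import org.apache.spark.sql.AnalysisException

val df = Seq((1, "a"), (2, "b")).toDF("k", "s")
try df.groupBy("k").sum("s")
catch {
  case e: AnalysisException =>
    // Per the diff above, the message reads:
    //   "s" is not a numeric or calendar interval column. Aggregation function
    //   can only be applied on a numeric or calendar interval column.
    println(e.getMessage)
}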
@@ -269,63 +269,64 @@ class RelationalGroupedDataset protected[sql](
   def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))
 
   /**
-   * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
+   * Compute the average value for each numeric or calendar interval column for each group. This
+   * is an alias for `avg`.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the average values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def mean(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Average)
   }
 
   /**
-   * Compute the max value for each numeric columns for each group.
+   * Compute the max value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the max values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def max(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Max)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Max)
   }
 
   /**
-   * Compute the mean value for each numeric columns for each group.
+   * Compute the mean value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the mean values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def avg(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Average)
   }
 
   /**
-   * Compute the min value for each numeric column for each group.
+   * Compute the min value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the min values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def min(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Min)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Min)
   }
 
   /**
-   * Compute the sum for each numeric columns for each group.
+   * Compute the sum for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the sum for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Sum)
   }

/**
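Putting the public API together, a minimal usage sketch (assumes a Spark build with this change and a spark-shell session where `spark.implicits._` is in scope; exact interval rendering varies by version):

import org.apache.spark.sql.types.CalendarIntervalType

val df = Seq(("a", "1 day"), ("a", "2 day"), ("b", "3 day"))
  .toDF("k", "iv")
  .select($"k", $"iv".cast(CalendarIntervalType).as("iv"))

df.groupBy("k").sum("iv").show()  // k = a -> 3 days, k = b -> 3 days
df.groupBy("k").min("iv").show()  // k = a -> 1 day,  k = b -> 3 days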
@@ -964,4 +964,24 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
       Row(3, new CalendarInterval(0, 3, 0)) :: Nil)
     assert(df3.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
   }
+
+  test("Dataset agg functions support calendar intervals") {
+    val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
+    val df2 = df1.select('a, 'b cast CalendarIntervalType).groupBy('a % 2)
+    checkAnswer(df2.sum("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 4, 0)) :: Nil)
+    checkAnswer(df2.avg("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.mean("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.max("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 3, 0)) :: Nil)
+    checkAnswer(df2.min("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 1, 0)) :: Nil)
+  }
 }
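For readers decoding the expected rows above: `CalendarInterval`'s three constructor arguments are months, days, and microseconds, so the assertions compare plain day counts. A standalone sketch (Spark 3.0-era API assumed):

import org.apache.spark.unsafe.types.CalendarInterval

val fourDays = new CalendarInterval(0, 4, 0)  // months = 0, days = 4, microseconds = 0
println(fourDays)                             // renders as an interval of 4 days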