From 4b353df372939713bb9988e78a1d26c4c45fbcef Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Tue, 26 Nov 2019 20:15:28 +0800
Subject: [PATCH 1/4] [SPARK-30048][SQL] Enable aggregates with interval type
 values for RelationalGroupedDataset

---
 .../spark/sql/RelationalGroupedDataset.scala  | 33 ++++++++++---------
 .../spark/sql/DataFrameAggregateSuite.scala   | 23 ++++++++++++-
 2 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index b1ba7d4538732..407799c418ca7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{StructType, TypeCollection}
 
 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -88,8 +88,8 @@ class RelationalGroupedDataset protected[sql](
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }
 
-  private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
-    : DataFrame = {
+  private[this] def aggregateNumericOrIntervalColumns(
+      colNames: String*)(f: Expression => AggregateFunction): DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
       // No columns specified. Use all numeric columns.
@@ -98,10 +98,10 @@ class RelationalGroupedDataset protected[sql](
       // Make sure all specified columns are numeric.
       colNames.map { colName =>
         val namedExpr = df.resolve(colName)
-        if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+        if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {
           throw new AnalysisException(
-            s""""$colName" is not a numeric column. """ +
-              "Aggregation function can only be applied on a numeric column.")
+            s""""$colName" is not a numeric or calendar interval column. """ +
+              "Aggregation function can only be applied on a numeric or calendar interval column.")
         }
         namedExpr
       }
@@ -269,7 +269,8 @@ class RelationalGroupedDataset protected[sql](
   def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))
 
   /**
-   * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
+   * Compute the average value for each numeric or calendar interval column for each group. This
+   * is an alias for `avg`.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the average values for them.
    *
@@ -277,11 +278,11 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def mean(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Average)
   }
 
   /**
-   * Compute the max value for each numeric columns for each group.
+   * Compute the max value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the max values for them.
    *
@@ -289,11 +290,11 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def max(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Max)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Max)
   }
 
   /**
-   * Compute the mean value for each numeric columns for each group.
+   * Compute the mean value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the mean values for them.
    *
@@ -301,11 +302,11 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def avg(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Average)
   }
 
   /**
-   * Compute the min value for each numeric column for each group.
+   * Compute the min value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the min values for them.
    *
@@ -313,11 +314,11 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def min(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Min)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Min)
   }
 
   /**
-   * Compute the sum for each numeric columns for each group.
+   * Compute the sum for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the sum for them.
    *
@@ -325,7 +326,7 @@ class RelationalGroupedDataset protected[sql](
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Sum)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index ec7b636c8f695..4635d3de878a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.DecimalData
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType}
+import org.apache.spark.unsafe.types.CalendarInterval
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
 
@@ -942,4 +943,24 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  test("calendar interval agg support hash aggregate") {
+    val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
+    val df2 = df1.select('a, 'b cast CalendarIntervalType).groupBy('a % 2)
+    checkAnswer(df2.sum("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 4, 0)) :: Nil)
+    checkAnswer(df2.avg("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.mean("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.max("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 3, 0)) :: Nil)
+    checkAnswer(df2.min("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 1, 0)) :: Nil)
+  }
 }

From f93cd9ff3b777fa56f26ae20e3a5bcaac14741f2 Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Tue, 26 Nov 2019 22:00:04 +0800
Subject: [PATCH 2/4] fix

---
 .../scala/org/apache/spark/sql/types/AbstractDataType.scala | 4 ++--
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala  | 6 ++++--
 .../org/apache/spark/sql/RelationalGroupedDataset.scala     | 6 +++---
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 21ac32adca6e9..25303475a73ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -79,8 +79,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 private[sql] object TypeCollection {
 
   /**
-   * Types that include numeric types and interval type. They are only used in unary_minus,
-   * unary_positive, add and subtract operations.
+   * Types that include numeric types and the interval type, which support numeric calculations,
+   * e.g. unary_minus, unary_positive, sum, avg, min, max, add and subtract operations.
    */
   val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e1bca44dfccf5..b29598e3a547f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -268,8 +268,10 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def numericColumns: Seq[Expression] = {
-    schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
+  private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
+    schema.fields.filter { f =>
+      TypeCollection.NumericAndInterval.acceptsType(f.dataType)
+    }.map { n =>
       queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 407799c418ca7..52bd0ecb1fffd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -92,10 +92,10 @@ class RelationalGroupedDataset protected[sql](
     colNames: String*)(f: Expression => AggregateFunction): DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
-      // No columns specified. Use all numeric columns.
-      df.numericColumns
+      // No columns specified. Use all columns that support numeric calculations.
+      df.numericCalculationSupportedColumns
     } else {
-      // Make sure all specified columns are numeric.
+      // Make sure all specified columns support numeric calculations.
       colNames.map { colName =>
         val namedExpr = df.resolve(colName)
         if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {

From 73b217a6fd7d26a033619e81743227ee7f0101fc Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Tue, 26 Nov 2019 23:23:29 +0800
Subject: [PATCH 3/4] refine

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b29598e3a547f..77a779a2f3105 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -269,10 +269,8 @@ class Dataset[T] private[sql](
   }
 
   private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
-    schema.fields.filter { f =>
-      TypeCollection.NumericAndInterval.acceptsType(f.dataType)
-    }.map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
+    queryExecution.analyzed.output.filter { attr =>
+      TypeCollection.NumericAndInterval.acceptsType(attr.dataType)
     }
   }

From fe84ae8bf071887865fbf1583d6144db95002a7d Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Wed, 27 Nov 2019 10:09:58 +0800
Subject: [PATCH 4/4] fix test name

---
 .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 4635d3de878a8..889981eb76620 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -944,7 +944,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
     }
   }
 
-  test("calendar interval agg support hash aggregate") {
+  test("Dataset agg functions support calendar intervals") {
     val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
     val df2 = df1.select('a, 'b cast CalendarIntervalType).groupBy('a % 2)
     checkAnswer(df2.sum("b"),
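
Reviewer note, not part of the patch series: a minimal usage sketch of what
these changes enable. It assumes a spark-shell session built from this branch
with spark.implicits._ in scope; the DataFrame and column names below are
illustrative, not taken from the patches.

  import org.apache.spark.sql.types.CalendarIntervalType

  // Durations arrive as strings and are cast to CalendarIntervalType,
  // mirroring the pattern used in the new DataFrameAggregateSuite test.
  val trips = Seq((1, "2 hours"), (1, "3 hours"), (2, "1 day"))
    .toDF("id", "dur")
    .select('id, 'dur cast CalendarIntervalType as "dur")

  // Before this series, sum/avg/mean/max/min on a RelationalGroupedDataset
  // rejected interval columns ("is not a numeric column"); with it, they
  // aggregate intervals per group.
  trips.groupBy('id).sum("dur").show()  // id=1 -> 5 hours, id=2 -> 1 day
  trips.groupBy('id).avg("dur").show()  // id=1 -> 2 hours 30 minutes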