|
18 | 18 | package org.apache.spark.sql.execution |
19 | 19 |
|
20 | 20 | import org.apache.spark.rdd.RDD |
21 | | -import org.apache.spark.sql.{execution, DataFrame, Row} |
| 21 | +import org.apache.spark.sql.{execution, Row} |
22 | 22 | import org.apache.spark.sql.catalyst.InternalRow |
23 | 23 | import org.apache.spark.sql.catalyst.expressions._ |
24 | 24 | import org.apache.spark.sql.catalyst.plans.Inner |
25 | 25 | import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} |
26 | 26 | import org.apache.spark.sql.catalyst.plans.physical._ |
| 27 | +import org.apache.spark.sql.execution.aggregate.SortAggregateExec |
27 | 28 | import org.apache.spark.sql.execution.columnar.InMemoryRelation |
28 | 29 | import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} |
29 | 30 | import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} |
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext { |
70 | 71 | s"The plan of query $query does not have partial aggregations.") |
71 | 72 | } |
72 | 73 |
|
| 74 | + test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") { |
| 75 | + withTempView("testSortBasedPartialAggregation") { |
| 76 | + val schema = StructType( |
| 77 | + StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil) |
| 78 | + val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString))) |
| 79 | + spark.createDataFrame(rowRDD, schema) |
| 80 | + .createOrReplaceTempView("testSortBasedPartialAggregation") |
| 81 | + |
| 82 | + // This test assumes a query below uses sort-based aggregations |
| 83 | + val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key") |
| 84 | + .queryExecution.executedPlan |
| 85 | + // This line extracts both SortAggregate and Sort operators |
| 86 | + val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n } |
| 87 | + val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n } |
| 88 | + assert(extractedOps.size == 4 && aggOps.size == 2, |
| 89 | + s"The plan $planned does not have correct sort-based partial aggregate pairs.") |
| 90 | + } |
| 91 | + } |
| 92 | + |
73 | 93 | test("non-partial aggregation for aggregates") { |
74 | 94 | withTempView("testNonPartialAggregation") { |
75 | 95 | val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) |
|
0 commit comments