Skip to content

Commit 3f3d024

Browse files
ulysses-you authored and yaooqinn committed
[SPARK-49205][SQL] KeyGroupedPartitioning should inherit HashPartitioningLike
### What changes were proposed in this pull request? This PR makes `KeyGroupedPartitioning` inherit `HashPartitioningLike`, so that `BroadcastHashJoin#expandOutputPartitioning` and `PartitioningPreservingUnaryExecNode` can work with it. ### Why are the changes needed? To make `KeyGroupedPartitioning` support the alias-aware framework. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47734 from ulysses-you/SPARK-49205-partitioning. Authored-by: ulysses-you <[email protected]> Signed-off-by: Kent Yao <[email protected]>
1 parent d82c695 commit 3f3d024

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ case class KeyGroupedPartitioning(
370370
expressions: Seq[Expression],
371371
numPartitions: Int,
372372
partitionValues: Seq[InternalRow] = Seq.empty,
373-
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
373+
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
374374

375375
override def satisfies0(required: Distribution): Boolean = {
376376
super.satisfies0(required) || {
@@ -421,6 +421,9 @@ case class KeyGroupedPartitioning(
421421
.distinct
422422
.map(_.row)
423423
}
424+
425+
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
426+
copy(expressions = newChildren)
424427
}
425428

426429
object KeyGroupedPartitioning {
@@ -766,8 +769,8 @@ case class CoalescedHashShuffleSpec(
766769
*
767770
* @param partitioning key grouped partitioning
768771
* @param distribution distribution
769-
* @param joinKeyPosition position of join keys among cluster keys.
770-
* This is set if joining on a subset of cluster keys is allowed.
772+
* @param joinKeyPositions position of join keys among cluster keys.
773+
* This is set if joining on a subset of cluster keys is allowed.
771774
*/
772775
case class KeyGroupedShuffleSpec(
773776
partitioning: KeyGroupedPartitioning,

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
package org.apache.spark.sql.connector
1818

19+
import java.sql.Timestamp
1920
import java.util.Collections
2021

2122
import org.apache.spark.SparkConf
@@ -583,6 +584,55 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
583584
}
584585
}
585586

587+
test("SPARK-49205: KeyGroupedPartitioning should inherit HashPartitioningLike") {
588+
val items_partitions = Array(days("arrive_time"))
589+
createTable(items, itemsColumns, items_partitions)
590+
sql(s"INSERT INTO testcat.ns.$items VALUES " +
591+
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
592+
"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
593+
"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
594+
"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
595+
"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
596+
597+
val purchases_partitions = Array(days("time"))
598+
createTable(purchases, purchasesColumns, purchases_partitions)
599+
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
600+
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
601+
"(1, 44.0, cast('2020-01-15' as timestamp)), " +
602+
"(1, 45.0, cast('2020-01-15' as timestamp)), " +
603+
"(2, 11.0, cast('2020-01-01' as timestamp)), " +
604+
"(3, 19.5, cast('2020-02-01' as timestamp))")
605+
606+
val df = sql(
607+
s"""
608+
|SELECT x, count(*) FROM (
609+
| SELECT /*+ broadcast(t2) */ arrive_time as x, * FROM testcat.ns.$items t1
610+
| JOIN testcat.ns.$purchases t2 ON t1.arrive_time = t2.time
611+
|)
612+
|GROUP BY x
613+
|""".stripMargin)
614+
checkAnswer(df,
615+
Seq(Row(Timestamp.valueOf("2020-01-01 00:00:00"), 6),
616+
Row(Timestamp.valueOf("2020-01-15 00:00:00"), 2),
617+
Row(Timestamp.valueOf("2020-02-01 00:00:00"), 1)))
618+
assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty)
619+
620+
val df2 = sql(
621+
s"""
622+
|WITH t1 (SELECT * FROM testcat.ns.$items)
623+
|SELECT x, count(*) FROM (
624+
| SELECT /*+ broadcast(t2) */ t2.time as x FROM t1
625+
| JOIN testcat.ns.$purchases t2 ON t1.arrive_time = t2.time
626+
| JOIN t1 t3 ON t1.arrive_time = t3.arrive_time
627+
|) GROUP BY x
628+
|""".stripMargin)
629+
checkAnswer(df2,
630+
Seq(Row(Timestamp.valueOf("2020-01-01 00:00:00"), 18),
631+
Row(Timestamp.valueOf("2020-01-15 00:00:00"), 2),
632+
Row(Timestamp.valueOf("2020-02-01 00:00:00"), 1)))
633+
assert(collectAllShuffles(df2.queryExecution.executedPlan).isEmpty)
634+
}
635+
586636
test("SPARK-42038: partially clustered: with same partition keys and one side fully clustered") {
587637
val items_partitions = Array(identity("id"))
588638
createTable(items, itemsColumns, items_partitions)

0 commit comments

Comments
 (0)