diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2364130f79e4c..43aba478c37be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -385,8 +385,9 @@ case class KeyGroupedPartitioning(
             val attributes = expressions.flatMap(_.collectLeaves())
 
             if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
-              // check that all join keys (required clustering keys) contained in partitioning
-              requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
+              // check that join keys (required clustering keys)
+              // overlap with partition keys (KeyGroupedPartitioning attributes)
+              requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
                 expressions.forall(_.collectLeaves().size == 1)
             } else {
               attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index ec275fe101fd6..10a32441b6cd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1227,6 +1227,66 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("SPARK-48065: SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict") {
+    val table1 = "tab1e1"
+    val table2 = "table2"
+    val partition = Array(identity("id"))
+    createTable(table1, columns, partition)
+    sql(s"INSERT INTO testcat.ns.$table1 VALUES " +
+      "(1, 'aa', cast('2020-01-01' as timestamp)), " +
+      "(2, 'bb', cast('2020-01-01' as timestamp)), " +
+      "(2, 'cc', cast('2020-01-01' as timestamp)), " +
+      "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+      "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+      "(3, 'ee', cast('2020-01-01' as timestamp)), " +
+      "(3, 'ee', cast('2020-01-01' as timestamp))")
+
+    createTable(table2, columns, partition)
+    sql(s"INSERT INTO testcat.ns.$table2 VALUES " +
+      "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+      "(4, 'zz', cast('2020-01-01' as timestamp)), " +
+      "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+      "(3, 'dd', cast('2020-01-01' as timestamp)), " +
+      "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+      "(3, 'xx', cast('2020-01-01' as timestamp)), " +
+      "(2, 'ww', cast('2020-01-01' as timestamp))")
+
+    Seq(true, false).foreach { pushDownValues =>
+      Seq(true, false).foreach { partiallyClustered =>
+        withSQLConf(
+          SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
+            partiallyClustered.toString,
+          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+          val df = sql(
+            s"""
+               |${selectWithMergeJoinHint("t1", "t2")}
+               |t1.id AS id, t1.data AS t1data, t2.data AS t2data
+               |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2
+               |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data
+               |""".stripMargin)
+          val shuffles = collectShuffles(df.queryExecution.executedPlan)
+          assert(shuffles.isEmpty, "SPJ should be triggered")
+
+          val scans = collectScans(df.queryExecution.executedPlan)
+            .map(_.inputRDD.partitions.length)
+          if (partiallyClustered) {
+            assert(scans == Seq(8, 8))
+          } else {
+            assert(scans == Seq(4, 4))
+          }
+          checkAnswer(df, Seq(
+            Row(3, "dd", "dd"),
+            Row(3, "dd", "dd"),
+            Row(3, "dd", "dd"),
+            Row(3, "dd", "dd")
+          ))
+        }
+      }
+    }
+  }
+
   test("SPARK-44647: test join key is subset of cluster key " +
     "with push values and partially-clustered") {
     val table1 = "tab1e1"
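
The behavioral core of the patch is the one-token change in satisfies0: requiredClustering.forall(...) becomes requiredClustering.exists(...), so a join now qualifies for storage-partitioned join (SPJ) when its keys merely overlap the partition keys rather than all being partition keys. A minimal sketch of that difference in plain Scala (the object and value names are illustrative, not Spark's API):

object JoinKeySubsetSketch {
  def main(args: Array[String]): Unit = {
    val partitionKeys = Seq("id")        // table is partitioned by `id` only
    val joinKeys = Seq("id", "data")     // query joins on `id` AND `data`

    // Old check: rejects the plan because `data` is not a partition key.
    val oldSatisfies = joinKeys.forall(partitionKeys.contains)

    // New check: accepts the plan because `id` overlaps the partition keys,
    // so SPJ can still avoid the shuffle.
    val newSatisfies = joinKeys.exists(partitionKeys.contains)

    println(s"forall (old): $oldSatisfies, exists (new): $newSatisfies")
  }
}

This is exactly the shape of the new test: both tables are partitioned by id alone while the join is on id and data, so the old forall check returned false and forced a shuffle, whereas the relaxed exists check lets the test assert shuffles.isEmpty.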