diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 50965c1abc68..90ab0e0df011 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1092,6 +1092,15 @@ object SQLConf {
       .intConf
       .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)
 
+  val SHUFFLED_JOIN_CHILDREN_PARTITIONING_DETECTION =
+    buildConf("spark.sql.shuffledJoin.childrenPartitioningDetection")
+      .internal()
+      .doc("When true, sort merge join and shuffled hash join will detect children data " +
+        "partitioning to avoid shuffle, it is helpful when join keys are a super-set of " +
+        "bucket keys")
+      .booleanConf
+      .createWithDefault(true)
+
   val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
     buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
       .internal()
@@ -1669,6 +1678,9 @@ class SQLConf extends Serializable with Logging {
 
   def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def shuffledJoinChildrenPartitioningDetection: Boolean =
+    getConf(SHUFFLED_JOIN_CHILDREN_PARTITIONING_DETECTION)
+
   def sortMergeJoinExecBufferInMemoryThreshold: Int =
     getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinUtils.scala
new file mode 100644
index 000000000000..d5a9f0235787
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinUtils.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashClusteredDistribution, HashPartitioning}
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+
+/**
+ * Helpers for computing the child distributions required by shuffle-based joins.
+ */
+object JoinUtils {
+  /** The distributions required when no child partitioning is reused. */
+  private def defaultDistribution(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression]): Seq[Distribution] = {
+    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  }
+
+  /**
+   * Builds the required distributions from the join keys that one child is already
+   * hash-partitioned on.
+   *
+   * @param joinKeys join keys of the child whose partitioning is reused
+   * @param expressions hash partitioning expressions of that child
+   * @param leftKeys all join keys of the left child
+   * @param rightKeys all join keys of the right child
+   * @return distributions over the left and right keys at the positions selected by
+   *         `expressions`, or over all join keys when no such positions can be derived
+   */
+  private def avoidShuffleIfPossible(
+      joinKeys: Seq[Expression],
+      expressions: Seq[Expression],
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression]): Seq[Distribution] = {
+    // Position of each partitioning expression within the join keys; -1 when it has no match.
+    val indices = expressions.map(x => joinKeys.indexWhere(_.semanticEquals(x)))
+    // Indexing the keys with -1 would throw and an empty selection leaves nothing to cluster
+    // on, so in both cases keep requiring clustering on all join keys.
+    if (indices.isEmpty || indices.contains(-1)) {
+      defaultDistribution(leftKeys, rightKeys)
+    } else {
+      HashClusteredDistribution(indices.map(leftKeys(_))) ::
+        HashClusteredDistribution(indices.map(rightKeys(_))) :: Nil
+    }
+  }
+
+  /**
+   * Computes the child distributions required by a shuffle-based join.
+   *
+   * When `partitioningDetection` is true and a child's output partitioning is a
+   * `HashPartitioning` that satisfies a `ClusteredDistribution` on that child's join keys,
+   * the requirement is narrowed to the keys that child is already partitioned on. The left
+   * child is checked first, then the right one. In every other case both children are
+   * required to be hash-clustered on all of their join keys.
+   *
+   * @param partitioningDetection whether to inspect the children's output partitioning
+   * @param leftKeys join keys of the left child
+   * @param rightKeys join keys of the right child
+   * @param left left child plan
+   * @param right right child plan
+   * @return the required distributions of the left and the right child, in that order
+   */
+  def requiredChildDistributionForShuffledJoin(
+      partitioningDetection: Boolean,
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      left: SparkPlan,
+      right: SparkPlan): Seq[Distribution] = {
+    if (!partitioningDetection) {
+      defaultDistribution(leftKeys, rightKeys)
+    } else {
+      (left.outputPartitioning, right.outputPartitioning) match {
+        case (partitioning @ HashPartitioning(leftExpressions, _), _)
+            if partitioning.satisfies(ClusteredDistribution(leftKeys)) =>
+          avoidShuffleIfPossible(leftKeys, leftExpressions, leftKeys, rightKeys)
+
+        case (_, partitioning @ HashPartitioning(rightExpressions, _))
+            if partitioning.satisfies(ClusteredDistribution(rightKeys)) =>
+          avoidShuffleIfPossible(rightKeys, rightExpressions, leftKeys, rightKeys)
+
+        case _ =>
+          defaultDistribution(leftKeys, rightKeys)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 897a4dae39f3..43ac22a61f8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -45,8 +45,10 @@ case class ShuffledHashJoinExec(
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
     "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  override def requiredChildDistribution: Seq[Distribution] = {
+    JoinUtils.requiredChildDistributionForShuffledJoin(
+      conf.shuffledJoinChildrenPartitioningDetection, leftKeys, rightKeys, left, right)
+  }
 
   private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index f4b9d132122e..c08337ca4a26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -77,8 +77,10 @@ case class SortMergeJoinExec(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
   }
 
-  override def requiredChildDistribution: Seq[Distribution] =
-    HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+  override def requiredChildDistribution: Seq[Distribution] = {
+    JoinUtils.requiredChildDistributionForShuffledJoin(
+      conf.shuffledJoinChildrenPartitioningDetection, leftKeys, rightKeys, left, right)
+  }
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
     // For inner join, orders of both sides keys should be kept.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index a9414200e70f..aae20c2928b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -408,16 +408,28 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     )
   }
 
-  // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
-  ignore("avoid shuffle when join keys are a super-set of bucket keys") {
-    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
-    val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
-    testBucketing(
-      bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
-      bucketedTableTestSpecRight = bucketedTableTestSpecRight,
-      joinCondition = joinCondition(Seq("i", "j"))
-    )
+  test("avoid shuffle when join keys are a super-set of bucket keys") {
+    Seq("i", "j").foreach { bucketColumn =>
+      val bucketSpec = Some(BucketSpec(8, Seq(bucketColumn), Nil))
+
+      // Both sides are bucketed on the same join key: expect no shuffle on either side.
+      val bucketedTableTestSpecLeft1 = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+      val bucketedTableTestSpecRight1 = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+      testBucketing(
+        bucketedTableTestSpecLeft = bucketedTableTestSpecLeft1,
+        bucketedTableTestSpecRight = bucketedTableTestSpecRight1,
+        joinCondition = joinCondition(Seq("i", "j"))
+      )
+
+      // Only the left side is bucketed: expect a shuffle on the right side only.
+      val bucketedTableTestSpecLeft2 = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
+      val bucketedTableTestSpecRight2 = BucketedTableTestSpec(None, expectedShuffle = true)
+      testBucketing(
+        bucketedTableTestSpecLeft = bucketedTableTestSpecLeft2,
+        bucketedTableTestSpecRight = bucketedTableTestSpecRight2,
+        joinCondition = joinCondition(Seq("i", "j"))
+      )
+    }
   }
 
   test("only shuffle one side when join bucketed table and non-bucketed table") {
@@ -647,8 +659,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     "partitioning columns") {
 
     // join predicates is a super set of child's partitioning columns
-    val bucketedTableTestSpec1 =
-      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
+    val bucketedTableTestSpec1 = BucketedTableTestSpec(
+      Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))),
+      numPartitions = 1,
+      expectedShuffle = false)
     testBucketing(
       bucketedTableTestSpecLeft = bucketedTableTestSpec1,
       bucketedTableTestSpecRight = bucketedTableTestSpec1,