Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,15 @@ object SQLConf {
.intConf
.createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get)

val SHUFFLED_JOIN_CHILDREN_PARTITIONING_DETECTION =
buildConf("spark.sql.shuffledJoin.childrenPartitioningDetection")
.internal()
.doc("When true, sort merge join and shuffled hash join will detect children data " +
"partitioning to avoid shuffle, it is helpful when join keys are a super-set of " +
"bucket keys")
.booleanConf
.createWithDefault(true)

val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD =
buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold")
.internal()
Expand Down Expand Up @@ -1669,6 +1678,9 @@ class SQLConf extends Serializable with Logging {

def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)

def shuffledJoinChildrenPartitioningDetection: Boolean =
getConf(SHUFFLED_JOIN_CHILDREN_PARTITIONING_DETECTION)

def sortMergeJoinExecBufferInMemoryThreshold: Int =
getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

object JoinUtils {
private def avoidShuffleIfPossible(
joinKeys: Seq[Expression],
expressions: Seq[Expression],
leftKeys: Seq[Expression],
rightKeys: Seq[Expression]): Seq[Distribution] = {
val indices = expressions.map(x => joinKeys.indexWhere(_.semanticEquals(x)))
HashClusteredDistribution(indices.map(leftKeys(_))) ::
HashClusteredDistribution(indices.map(rightKeys(_))) :: Nil
}

def requiredChildDistributionForShuffledJoin(
partitioningDetection: Boolean,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan): Seq[Distribution] = {
if (!partitioningDetection) {
return HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
}

val leftPartitioning = left.outputPartitioning
val rightPartitioning = right.outputPartitioning
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my biggest concern. Currently Spark adds shuffle with a rule, so we can't always get the children partitioning precisely. We implemented a similar feature in EnsureRequirements.reorderJoinPredicates, which is hacky and we should improve the framework before adding more features like this.

Copy link
Contributor Author

@yucai yucai Oct 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan in this PR, requiredChildDistribution is always re-calculated each time it is invoked, could it be more precise than EnsureRequirements.reorderJoinPredicates

This kind of bucketjoin is common, do we have a plan to improve the framework in 3.0?

leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftPartitioning.satisfies(ClusteredDistribution(leftKeys)) =>
avoidShuffleIfPossible(leftKeys, leftExpressions, leftKeys, rightKeys)

case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightPartitioning.satisfies(ClusteredDistribution(rightKeys)) =>
avoidShuffleIfPossible(rightKeys, rightExpressions, leftKeys, rightKeys)

case _ =>
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ case class ShuffledHashJoinExec(
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))

override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
override def requiredChildDistribution: Seq[Distribution] = {
JoinUtils.requiredChildDistributionForShuffledJoin(
conf.shuffledJoinChildrenPartitioningDetection, leftKeys, rightKeys, left, right)
}

private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val buildDataSize = longMetric("buildDataSize")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ case class SortMergeJoinExec(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}

override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
override def requiredChildDistribution: Seq[Distribution] = {
JoinUtils.requiredChildDistributionForShuffledJoin(
conf.shuffledJoinChildrenPartitioningDetection, leftKeys, rightKeys, left, right)
}

override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,16 +408,26 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
)
}

// Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
val bucketedTableTestSpecLeft = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
val bucketedTableTestSpecRight = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpecLeft,
bucketedTableTestSpecRight = bucketedTableTestSpecRight,
joinCondition = joinCondition(Seq("i", "j"))
)
test("avoid shuffle when join keys are a super-set of bucket keys") {
Seq("i", "j").foreach { bucketColumn =>
val bucketSpec = Some(BucketSpec(8, Seq(s"$bucketColumn"), Nil))

val bucketedTableTestSpecLeft1 = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
val bucketedTableTestSpecRight1 = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpecLeft1,
bucketedTableTestSpecRight = bucketedTableTestSpecRight1,
joinCondition = joinCondition(Seq("i", "j"))
)

val bucketedTableTestSpecLeft2 = BucketedTableTestSpec(bucketSpec, expectedShuffle = false)
val bucketedTableTestSpecRight2 = BucketedTableTestSpec(None, expectedShuffle = true)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpecLeft2,
bucketedTableTestSpecRight = bucketedTableTestSpecRight2,
joinCondition = joinCondition(Seq("i", "j"))
)
}
}

test("only shuffle one side when join bucketed table and non-bucketed table") {
Expand Down Expand Up @@ -647,8 +657,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
"partitioning columns") {

// join predicates is a super set of child's partitioning columns
val bucketedTableTestSpec1 =
BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
val bucketedTableTestSpec1 = BucketedTableTestSpec(
Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))),
numPartitions = 1,
expectedShuffle = false)
testBucketing(
bucketedTableTestSpecLeft = bucketedTableTestSpec1,
bucketedTableTestSpecRight = bucketedTableTestSpec1,
Expand Down