Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
import org.apache.spark.sql.execution.adaptive.rule.ReduceNumShufflePartitions
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.internal.SQLConf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,48 +24,77 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.internal.SQLConf

case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
/**
 * Extractor matching a [[BroadcastHashJoinExec]] whose left child satisfies
 * `ShuffleQueryStageExec.isShuffleQueryStageExec`.
 *
 * On a match it yields the left child (as a [[QueryStageExec]]) paired with the join's
 * build side, so callers can additionally pattern match on `BuildLeft` / `BuildRight`.
 * Any other plan node yields `None`.
 */
object BroadcastJoinWithShuffleLeft {
def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match {
case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) =>
// NOTE(review): the cast presumes the guard above only accepts QueryStageExec nodes — confirm.
Some((join.left.asInstanceOf[QueryStageExec], join.buildSide))
case _ => None
}
}

def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
/**
 * Extractor matching a [[BroadcastHashJoinExec]] whose right child satisfies
 * `ShuffleQueryStageExec.isShuffleQueryStageExec`.
 *
 * On a match it yields the right child (as a [[QueryStageExec]]) paired with the join's
 * build side, so callers can additionally pattern match on `BuildLeft` / `BuildRight`.
 * Any other plan node yields `None`.
 */
object BroadcastJoinWithShuffleRight {
def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match {
case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) =>
// NOTE(review): the cast presumes the guard above only accepts QueryStageExec nodes — confirm.
Some((join.right.asInstanceOf[QueryStageExec], join.buildSide))
case _ => None
}
}

/**
* A rule to optimize the shuffle reader to a local reader as far as possible
* when converting a 'SortMergeJoinExec' to a 'BroadcastHashJoinExec' at runtime.
*
* This rule can be divided into two steps:
* Step1: Add the local reader on the probe side and then check whether an additional
* shuffle is introduced. If one is introduced, revert all the local
* readers on the probe side.
* Step2: Add the local reader on the build side without checking whether an
* additional shuffle is introduced, because the build side will not introduce
* an additional shuffle.
*/
case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
return plan
}

val optimizedPlan = plan.transformDown {
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) =>
val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
join.copy(right = localReader)
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
join.copy(left = localReader)
// Add local reader in probe side.
val withProbeSideLocalReader = plan.transformDown {
case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) =>
val localReader = LocalShuffleReaderExec(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildLeft) =>
val localReader = LocalShuffleReaderExec(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
}

def numExchanges(plan: SparkPlan): Int = {
plan.collect {
case e: ShuffleExchangeExec => e
}.length
}

// Check whether an additional shuffle is introduced. If so, revert the local reader.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this rule adds a local shuffle reader for every BroadcastHashJoinExec and then reverts all local shuffle readers if any one of them causes an additional shuffle.

Can we revert only the local shuffle readers that cause an additional shuffle and keep those that do not?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the best, but I don't know if there is an easy way to do it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now we can implement this by reverting all the local readers, and re-optimize later when we find a better way.

val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan))
val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan))

if (numExchangeAfter > numExchangeBefore) {
logDebug("OptimizeLocalShuffleReader rule is not applied due" +
val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(withProbeSideLocalReader))
val optimizedPlan = if (numExchangeAfter > numExchangeBefore) {
logDebug("OptimizeLocalShuffleReader rule is not applied in the probe side due" +
" to additional shuffles will be introduced.")
plan
} else {
optimizedPlan
withProbeSideLocalReader
}
// Add the local reader on the build side; there is no need to check whether
// an additional shuffle is introduced.
optimizedPlan.transformDown {
case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildLeft) =>
val localReader = LocalShuffleReaderExec(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildRight) =>
val localReader = LocalShuffleReaderExec(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
}
}
}
Expand Down Expand Up @@ -108,25 +137,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode {
}
cachedShuffleRDD
}

override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
append: String => Unit,
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false,
maxFields: Int,
printNodeId: Boolean): Unit = {
super.generateTreeString(depth,
lastChildren,
append,
verbose,
prefix,
addSuffix,
maxFields,
printNodeId)
child.generateTreeString(
depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive.rule
package org.apache.spark.sql.execution.adaptive

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration.Duration
Expand All @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{LocalShuffleReaderExec, QueryStageExec, ReusedQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.ThreadUtils

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite}
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.adaptive.rule.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions}
import org.apache.spark.sql.execution.adaptive.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -110,7 +110,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -125,7 +125,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -141,7 +141,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -163,9 +163,28 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining one BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// BroadcastHashJoin
// +- BroadcastExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastHashJoin
// +- BroadcastExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastHashJoin
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastExchange
// +-LocalShuffleReader*
// +- ShuffleExchange

// After applying the 'OptimizeLocalShuffleReader' rule, we can convert all four
// shuffle readers to local shuffle readers in the bottom two 'BroadcastHashJoin's.
// For the top level 'BroadcastHashJoin', the probe side is not a shuffle query stage
// and the build side shuffle query stage is also converted to a local shuffle reader.
checkNumLocalShuffleReaders(adaptivePlan, 5)
}
}

Expand All @@ -189,9 +208,24 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// BroadcastHashJoin
// +- BroadcastExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastHashJoin
// +- BroadcastExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastHashJoin
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastExchange
// +-HashAggregate
// +- CoalescedShuffleReader
// +- ShuffleExchange
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

Expand All @@ -215,9 +249,25 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// BroadcastHashJoin
// +- BroadcastExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastHashJoin
// +- BroadcastExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- LocalShuffleReader*
// +- ShuffleExchange
// +- BroadcastHashJoin
// +- Filter
// +- HashAggregate
// +- CoalescedShuffleReader
// +- ShuffleExchange
// +- BroadcastExchange
// +-LocalShuffleReader*
// +- ShuffleExchange
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

Expand All @@ -232,8 +282,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 2)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 4)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
}
Expand All @@ -250,8 +300,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
}
Expand All @@ -270,8 +320,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
val sub = findReusedSubquery(adaptivePlan)
Expand All @@ -291,8 +341,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
val sub = findReusedSubquery(adaptivePlan)
Expand All @@ -315,8 +365,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
assert(ex.head.plan.isInstanceOf[BroadcastQueryStageExec])
Expand Down Expand Up @@ -393,8 +443,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 2)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule.
checkNumLocalShuffleReaders(adaptivePlan, 0)
// Even though an additional shuffle exchange is introduced, we can still
// convert the shuffle reader to a local reader on the build side.
checkNumLocalShuffleReaders(adaptivePlan, 1)
}
}

Expand Down