-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28560][SQL][followup] support the build side to local shuffle reader as far as possible in BroadcastHashJoin #26289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5b7ff2d
e13f637
782827a
f3bb9ce
152aaa6
e510e96
1e947db
573ffcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,48 +24,77 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit | |
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} | ||
| import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} | ||
| import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight} | ||
| import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
||
| case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { | ||
|
|
||
| def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = { | ||
| join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) | ||
| object BroadcastJoinWithShuffleLeft { | ||
| def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match { | ||
| case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) => | ||
| Some((join.left.asInstanceOf[QueryStageExec], join.buildSide)) | ||
| case _ => None | ||
| } | ||
| } | ||
|
|
||
| def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = { | ||
| join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) | ||
| object BroadcastJoinWithShuffleRight { | ||
| def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match { | ||
| case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) => | ||
| Some((join.right.asInstanceOf[QueryStageExec], join.buildSide)) | ||
| case _ => None | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A rule to optimize the shuffle reader to local reader as far as possible | ||
| * when converting the 'SortMergeJoinExec' to 'BroadcastHashJoinExec' in runtime. | ||
| * | ||
| * This rule can be divided into two steps: | ||
| * Step1: Add the local reader in probe side and then check whether additional | ||
| * shuffle introduced. If introduced, we will revert all the local | ||
| * reader in probe side. | ||
| * Step2: Add the local reader in build side and will not check whether | ||
| * additional shuffle introduced. Because the build side will not introduce | ||
| * additional shuffle. | ||
| */ | ||
| case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { | ||
|
|
||
| override def apply(plan: SparkPlan): SparkPlan = { | ||
| if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) { | ||
| return plan | ||
| } | ||
|
|
||
| val optimizedPlan = plan.transformDown { | ||
| case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) => | ||
| val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec]) | ||
| join.copy(right = localReader) | ||
| case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) => | ||
| val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec]) | ||
| join.copy(left = localReader) | ||
| // Add local reader in probe side. | ||
| val withProbeSideLocalReader = plan.transformDown { | ||
| case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) => | ||
| val localReader = LocalShuffleReaderExec(shuffleStage) | ||
| join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader) | ||
| case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildLeft) => | ||
| val localReader = LocalShuffleReaderExec(shuffleStage) | ||
| join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader) | ||
| } | ||
|
|
||
| def numExchanges(plan: SparkPlan): Int = { | ||
| plan.collect { | ||
| case e: ShuffleExchangeExec => e | ||
| }.length | ||
| } | ||
|
|
||
| // Check whether additional shuffle introduced. If introduced, revert the local reader. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Now this rule converts the shuffle reader to a local shuffle reader for all BroadcastHashJoinExec nodes and then reverts all local shuffle readers if any of them causes an additional shuffle. Can we just revert the local shuffle readers that cause additional shuffles and keep the ones that do not?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That would be the best, but I don't know if there is an easy way to do it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can keep the current implementation of reverting all the local readers for now, and re-optimize later when we find a better way. |
||
| val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan)) | ||
| val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan)) | ||
|
|
||
| if (numExchangeAfter > numExchangeBefore) { | ||
| logDebug("OptimizeLocalShuffleReader rule is not applied due" + | ||
| val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(withProbeSideLocalReader)) | ||
| val optimizedPlan = if (numExchangeAfter > numExchangeBefore) { | ||
| logDebug("OptimizeLocalShuffleReader rule is not applied in the probe side due" + | ||
| " to additional shuffles will be introduced.") | ||
| plan | ||
| } else { | ||
| optimizedPlan | ||
| withProbeSideLocalReader | ||
| } | ||
| // Add the local reader in build side and do not need to check whether | ||
| // additional shuffle introduced. | ||
| optimizedPlan.transformDown { | ||
| case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildLeft) => | ||
| val localReader = LocalShuffleReaderExec(shuffleStage) | ||
| join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader) | ||
| case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildRight) => | ||
| val localReader = LocalShuffleReaderExec(shuffleStage) | ||
| join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -108,25 +137,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode { | |
| } | ||
| cachedShuffleRDD | ||
| } | ||
|
|
||
| override def generateTreeString( | ||
| depth: Int, | ||
| lastChildren: Seq[Boolean], | ||
| append: String => Unit, | ||
| verbose: Boolean, | ||
| prefix: String = "", | ||
| addSuffix: Boolean = false, | ||
| maxFields: Int, | ||
| printNodeId: Boolean): Unit = { | ||
| super.generateTreeString(depth, | ||
| lastChildren, | ||
| append, | ||
| verbose, | ||
| prefix, | ||
| addSuffix, | ||
| maxFields, | ||
| printNodeId) | ||
| child.generateTreeString( | ||
| depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.