diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index f01947d8f5ed..8d4731f34ddd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -31,13 +31,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
-import org.apache.spark.sql.execution.adaptive.rule.ReduceNumShufflePartitions
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index d8dd7224fef3..4ddb2154116e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -24,31 +24,51 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.internal.SQLConf
 
-case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
-
-  def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
-    join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
+object BroadcastJoinWithShuffleLeft {
+  def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match {
+    case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) =>
+      Some((join.left.asInstanceOf[QueryStageExec], join.buildSide))
+    case _ => None
   }
+}
 
-  def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
-    join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
+object BroadcastJoinWithShuffleRight {
+  def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match {
+    case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) =>
+      Some((join.right.asInstanceOf[QueryStageExec], join.buildSide))
+    case _ => None
   }
+}
+
+/**
+ * A rule to optimize the shuffle readers to local readers as much as possible
+ * when converting a 'SortMergeJoinExec' to a 'BroadcastHashJoinExec' at runtime.
+ *
+ * This rule can be divided into two steps:
+ * Step 1: Add local readers on the probe side, then check whether additional
+ *         shuffles are introduced. If they are, revert all the local
+ *         readers on the probe side.
+ * Step 2: Add local readers on the build side without the shuffle check,
+ *         because the build side can never introduce
+ *         additional shuffles.
+ */
+case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
 
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
       return plan
     }
-
-    val optimizedPlan = plan.transformDown {
-      case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) =>
-        val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
-        join.copy(right = localReader)
-      case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
-        val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
-        join.copy(left = localReader)
+    // Add local readers on the probe side.
+    val withProbeSideLocalReader = plan.transformDown {
+      case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) =>
+        val localReader = LocalShuffleReaderExec(shuffleStage)
+        join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
+      case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildLeft) =>
+        val localReader = LocalShuffleReaderExec(shuffleStage)
+        join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
     }
 
     def numExchanges(plan: SparkPlan): Int = {
@@ -56,16 +76,25 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
         case e: ShuffleExchangeExec => e
       }.length
     }
-
+    // Check whether additional shuffles are introduced. If so, revert the local readers.
     val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan))
-    val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan))
-
-    if (numExchangeAfter > numExchangeBefore) {
-      logDebug("OptimizeLocalShuffleReader rule is not applied due" +
+    val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(withProbeSideLocalReader))
+    val optimizedPlan = if (numExchangeAfter > numExchangeBefore) {
+      logDebug("OptimizeLocalShuffleReader rule is not applied on the probe side due" +
         " to additional shuffles will be introduced.")
       plan
     } else {
-      optimizedPlan
+      withProbeSideLocalReader
+    }
+    // Add local readers on the build side; there is no need to check whether
+    // additional shuffles are introduced.
+    optimizedPlan.transformDown {
+      case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildLeft) =>
+        val localReader = LocalShuffleReaderExec(shuffleStage)
+        join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
+      case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildRight) =>
+        val localReader = LocalShuffleReaderExec(shuffleStage)
+        join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
     }
   }
 }
@@ -108,25 +137,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode {
     }
     cachedShuffleRDD
   }
-
-  override def generateTreeString(
-      depth: Int,
-      lastChildren: Seq[Boolean],
-      append: String => Unit,
-      verbose: Boolean,
-      prefix: String = "",
-      addSuffix: Boolean = false,
-      maxFields: Int,
-      printNodeId: Boolean): Unit = {
-    super.generateTreeString(depth,
-      lastChildren,
-      append,
-      verbose,
-      prefix,
-      addSuffix,
-      maxFields,
-      printNodeId)
-    child.generateTreeString(
-      depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
index 5a505c213a26..3b02ddadd2da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.adaptive.rule
+package org.apache.spark.sql.execution.adaptive
 
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration.Duration
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{LocalShuffleReaderExec, QueryStageExec, ReusedQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index b5dbdd0b18b4..4d408cd8ebd7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.adaptive._
-import org.apache.spark.sql.execution.adaptive.rule.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions}
+import org.apache.spark.sql.execution.adaptive.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 649467a27d93..b140b08950db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -93,7 +93,7 @@ class AdaptiveQueryExecSuite
       assert(smj.size == 1)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
     }
   }
 
@@ -110,7 +110,7 @@ class AdaptiveQueryExecSuite
 
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
     }
   }
 
@@ -125,7 +125,7 @@ class AdaptiveQueryExecSuite
       assert(smj.size == 1)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
     }
   }
 
@@ -141,7 +141,7 @@ class AdaptiveQueryExecSuite
 
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
     }
   }
 
@@ -163,9 +163,28 @@
       assert(smj.size == 3)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 3)
-      // The child of remaining one BroadcastHashJoin is not ShuffleQueryStage.
-      // So only two LocalShuffleReader.
-      checkNumLocalShuffleReaders(adaptivePlan, 2)
+      // BroadcastHashJoin
+      // +- BroadcastExchange
+      //    +- LocalShuffleReader*
+      //       +- ShuffleExchange
+      //          +- BroadcastHashJoin
+      //             +- BroadcastExchange
+      //                +- LocalShuffleReader*
+      //                   +- ShuffleExchange
+      //             +- LocalShuffleReader*
+      //                +- ShuffleExchange
+      // +- BroadcastHashJoin
+      //    +- LocalShuffleReader*
+      //       +- ShuffleExchange
+      //    +- BroadcastExchange
+      //       +- LocalShuffleReader*
+      //          +- ShuffleExchange
+
+      // After applying the 'OptimizeLocalShuffleReader' rule, all four shuffle readers
+      // in the bottom two 'BroadcastHashJoin's are converted to local shuffle readers.
+      // For the top-level 'BroadcastHashJoin', the probe side is not a shuffle query stage,
+      // while its build side shuffle query stage is also converted to a local shuffle reader.
+      checkNumLocalShuffleReaders(adaptivePlan, 5)
     }
   }
 
@@ -189,9 +208,24 @@
       assert(smj.size == 3)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 3)
-      // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
-      // So only two LocalShuffleReader.
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
+      // BroadcastHashJoin
+      // +- BroadcastExchange
+      //    +- LocalShuffleReader*
+      //       +- ShuffleExchange
+      //          +- BroadcastHashJoin
+      //             +- BroadcastExchange
+      //                +- LocalShuffleReader*
+      //                   +- ShuffleExchange
+      //             +- LocalShuffleReader*
+      //                +- ShuffleExchange
+      // +- BroadcastHashJoin
+      //    +- LocalShuffleReader*
+      //       +- ShuffleExchange
+      //    +- BroadcastExchange
+      //       +- HashAggregate
+      //          +- CoalescedShuffleReader
+      //             +- ShuffleExchange
+      checkNumLocalShuffleReaders(adaptivePlan, 4)
     }
   }
 
@@ -215,9 +249,25 @@
       assert(smj.size == 3)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 3)
-      // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
-      // So only two LocalShuffleReader.
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
+      // BroadcastHashJoin
+      // +- BroadcastExchange
+      //    +- LocalShuffleReader*
+      //       +- ShuffleExchange
+      //          +- BroadcastHashJoin
+      //             +- BroadcastExchange
+      //                +- LocalShuffleReader*
+      //                   +- ShuffleExchange
+      //             +- LocalShuffleReader*
+      //                +- ShuffleExchange
+      // +- BroadcastHashJoin
+      //    +- Filter
+      //       +- HashAggregate
+      //          +- CoalescedShuffleReader
+      //             +- ShuffleExchange
+      //    +- BroadcastExchange
+      //       +- LocalShuffleReader*
+      //          +- ShuffleExchange
+      checkNumLocalShuffleReaders(adaptivePlan, 4)
     }
   }
 
@@ -232,8 +282,8 @@
       assert(smj.size == 3)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 2)
-      checkNumLocalShuffleReaders(adaptivePlan, 2)
-      // Even with local shuffle reader, the query statge reuse can also work.
+      checkNumLocalShuffleReaders(adaptivePlan, 4)
+      // Even with the local shuffle reader, query stage reuse can still work.
       val ex = findReusedExchange(adaptivePlan)
       assert(ex.size == 1)
     }
@@ -250,8 +300,8 @@
       assert(smj.size == 1)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
-      // Even with local shuffle reader, the query statge reuse can also work.
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
+      // Even with the local shuffle reader, query stage reuse can still work.
       val ex = findReusedExchange(adaptivePlan)
       assert(ex.size == 1)
     }
@@ -270,8 +320,8 @@
       assert(smj.size == 1)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
-      // Even with local shuffle reader, the query statge reuse can also work.
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
+      // Even with the local shuffle reader, query stage reuse can still work.
       val ex = findReusedExchange(adaptivePlan)
       assert(ex.nonEmpty)
       val sub = findReusedSubquery(adaptivePlan)
@@ -291,8 +341,8 @@
       assert(smj.size == 1)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
-      // Even with local shuffle reader, the query statge reuse can also work.
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
+      // Even with the local shuffle reader, query stage reuse can still work.
       val ex = findReusedExchange(adaptivePlan)
       assert(ex.isEmpty)
       val sub = findReusedSubquery(adaptivePlan)
@@ -315,8 +365,8 @@
       assert(smj.size == 1)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      checkNumLocalShuffleReaders(adaptivePlan, 1)
-      // Even with local shuffle reader, the query statge reuse can also work.
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
+      // Even with the local shuffle reader, query stage reuse can still work.
       val ex = findReusedExchange(adaptivePlan)
       assert(ex.nonEmpty)
       assert(ex.head.plan.isInstanceOf[BroadcastQueryStageExec])
@@ -393,8 +443,9 @@
       assert(smj.size == 2)
       val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
       assert(bhj.size == 1)
-      // additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule.
-      checkNumLocalShuffleReaders(adaptivePlan, 0)
+      // Even when an additional shuffle exchange is introduced, we can
+      // still convert the shuffle reader to a local reader on the build side.
+      checkNumLocalShuffleReaders(adaptivePlan, 1)
     }
   }
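Note: the following self-contained sketch models the two-step structure of the
OptimizeLocalShuffleReader rule above on a toy plan algebra, for readers who want to
see the extractor-plus-transformDown pattern in isolation. Every name in it (Plan,
Scan, Shuffle, LocalReader, Join, JoinWithShuffleLeft/Right, transformDown,
numShuffles, optimize) is an illustrative stand-in invented for this sketch, not a
Spark API, and the shuffle-count check only mirrors the shape of the real rule's
EnsureRequirements re-check.

object LocalReaderSketch {
  // Toy stand-ins for Spark's physical operators.
  sealed trait Plan
  case object Scan extends Plan
  case class Shuffle(child: Plan) extends Plan
  case class LocalReader(child: Plan) extends Plan
  // buildRight = true: the right side is the broadcast (build) side,
  // so the left side is the probe side.
  case class Join(left: Plan, right: Plan, buildRight: Boolean) extends Plan

  // Extractors in the style of BroadcastJoinWithShuffleLeft/Right: match a join
  // whose left (or right) child is a shuffle, exposing it with the build flag.
  object JoinWithShuffleLeft {
    def unapply(p: Plan): Option[(Shuffle, Boolean)] = p match {
      case Join(s: Shuffle, _, b) => Some((s, b))
      case _ => None
    }
  }
  object JoinWithShuffleRight {
    def unapply(p: Plan): Option[(Shuffle, Boolean)] = p match {
      case Join(_, s: Shuffle, b) => Some((s, b))
      case _ => None
    }
  }

  // Top-down transform, standing in for SparkPlan.transformDown.
  def transformDown(p: Plan)(f: PartialFunction[Plan, Plan]): Plan =
    f.applyOrElse(p, identity[Plan]) match {
      case Join(l, r, b)  => Join(transformDown(l)(f), transformDown(r)(f), b)
      case Shuffle(c)     => Shuffle(transformDown(c)(f))
      case LocalReader(c) => LocalReader(transformDown(c)(f))
      case leaf           => leaf
    }

  def numShuffles(p: Plan): Int = p match {
    case Join(l, r, _)  => numShuffles(l) + numShuffles(r)
    case Shuffle(c)     => 1 + numShuffles(c)
    case LocalReader(c) => numShuffles(c)
    case Scan           => 0
  }

  def optimize(plan: Plan): Plan = {
    // Step 1: add local readers on the probe side only. With buildRight = true
    // the probe side is the left child, and vice versa.
    val withProbe = transformDown(plan) {
      case j @ JoinWithShuffleLeft(s, true)   => j.asInstanceOf[Join].copy(left = LocalReader(s))
      case j @ JoinWithShuffleRight(s, false) => j.asInstanceOf[Join].copy(right = LocalReader(s))
    }
    // Revert if the rewrite would add shuffles. The real rule re-runs
    // EnsureRequirements and compares exchange counts; this count is only a
    // placeholder with the same shape.
    val base = if (numShuffles(withProbe) > numShuffles(plan)) plan else withProbe
    // Step 2: add local readers on the build side unconditionally, since the
    // build side feeds a broadcast and cannot introduce extra shuffles.
    transformDown(base) {
      case j @ JoinWithShuffleLeft(s, false) => j.asInstanceOf[Join].copy(left = LocalReader(s))
      case j @ JoinWithShuffleRight(s, true) => j.asInstanceOf[Join].copy(right = LocalReader(s))
    }
  }

  def main(args: Array[String]): Unit = {
    // Probe side (left) and build side (right) both read from shuffles.
    val plan = Join(Shuffle(Scan), Shuffle(Scan), buildRight = true)
    // Both sides end up wrapped in LocalReader:
    // Join(LocalReader(Shuffle(Scan)),LocalReader(Shuffle(Scan)),true)
    println(optimize(plan))
  }
}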