From 14262941d80f5f5b0075ee294769af3e81c2d7b6 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 20 Mar 2019 20:03:11 -0500 Subject: [PATCH 01/12] implement join strategy hints --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/ResolveHints.scala | 63 ++-- .../sql/catalyst/catalog/interface.scala | 2 +- .../optimizer/EliminateResolvedHint.scala | 5 +- .../sql/catalyst/plans/logical/hints.scala | 69 ++++- .../catalyst/analysis/ResolveHintsSuite.scala | 21 +- .../spark/sql/execution/SparkStrategies.scala | 92 ++++-- .../org/apache/spark/sql/functions.scala | 4 +- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/JoinHintSuite.scala | 269 ++++++++++++++++-- .../execution/joins/BroadcastJoinSuite.scala | 5 +- 11 files changed, 443 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e4cf43d43586..b7fc97063cfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -146,7 +146,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, - new ResolveHints.ResolveBroadcastHints(conf), + new ResolveHints.ResolveJoinStrategyHints(conf), ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index dbd4ed845e32..4231775682ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -28,45 +28,56 @@ import org.apache.spark.sql.internal.SQLConf /** - * 
Collection of rules related to hints. The only hint currently available is broadcast join hint. + * Collection of rules related to hints. The only hint currently available is join strategy hint. * * Note that this is separately into two rules because in the future we might introduce new hint - * rules that have different ordering requirements from broadcast. + * rules that have different ordering requirements from join strategies. */ object ResolveHints { /** - * For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of - * relation aliases can be specified in the hint. A broadcast hint plan node will be inserted - * on top of any relation (that is not aliased differently), subquery, or common table expression - * that match the specified name. + * The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a + * sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)", + * "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that + * is not aliased differently), subquery, or common table expression that matches the specified + * name. * * The hint resolution works by recursively traversing down the query plan to find a relation or - * subquery that matches one of the specified broadcast aliases. The traversal does not go past - * beyond any existing broadcast hints, subquery aliases. + * subquery that matches one of the specified relation aliases. The traversal does not go + * beyond any view reference, with clause or subquery alias. * * This rule must happen before common table expressions. 
*/ - class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { - private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] { + private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases) def resolver: Resolver = conf.resolver - private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + private def createHintInfo(hintName: String): HintInfo = { + HintInfo(strategy = + JoinStrategyHint.strategies.find( + _.hintAliases.map( + _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT)))) + } + + private def applyJoinStrategyHint( + plan: LogicalPlan, + relations: Set[String], + hintName: String): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - ResolvedHint(plan, HintInfo(broadcast = true)) - case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - ResolvedHint(plan, HintInfo(broadcast = true)) + case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) => + ResolvedHint(plan, createHintInfo(hintName)) + case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) => + ResolvedHint(plan, createHintInfo(hintName)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. - // For an existing broadcast hint, there is no point going down (if we do, we either - // won't change the structure, or will introduce another broadcast hint that is useless. + // For an existing strategy hint, there is no point going down (if we do, we either + // won't change the structure, or will introduce another strategy hint that is useless. // The rest (view, with, subquery) indicates different scopes that we shouldn't traverse // down. 
Note that technically when this rule is executed, we haven't completed view // resolution yet and as a result the view part should be deadcode. I'm leaving it here @@ -80,25 +91,25 @@ object ResolveHints { } if ((plan fastEquals newNode) && recurse) { - newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast)) + newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName)) } else { newNode } } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { - // If there is no table alias specified, turn the entire subtree into a BroadcastHint. - ResolvedHint(h.child, HintInfo(broadcast = true)) + // If there is no table alias specified, apply the hint on the entire subtree. + ResolvedHint(h.child, createHintInfo(h.name)) } else { - // Otherwise, find within the subtree query plans that should be broadcasted. - applyBroadcastHint(h.child, h.parameters.map { + // Otherwise, find within the subtree query plans to apply the hint. 
+ applyJoinStrategyHint(h.child, h.parameters.map { case tableName: String => tableName case tableId: UnresolvedAttribute => tableId.name - case unsupported => throw new AnalysisException("Broadcast hint parameter should be " + - s"an identifier or string but was $unsupported (${unsupported.getClass}") - }.toSet) + case unsupported => throw new AnalysisException("Join strategy hint parameter " + + s"should be an identifier or string but was $unsupported (${unsupported.getClass}") + }.toSet, h.name) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 60066371e3dc..2d646721f87a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -374,7 +374,7 @@ object CatalogTable { /** * This class of statistics is used in [[CatalogTable]] to interact with metastore. * We define this new class instead of directly using [[Statistics]] here because there are no - * concepts of attributes or broadcast hint in catalog. + * concepts of attributes in catalog. 
*/ case class CatalogStatistics( sizeInBytes: BigInt, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala index a136f0493699..df50bd7840c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala @@ -40,8 +40,7 @@ object EliminateResolvedHint extends Rule[LogicalPlan] { } private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = { - hints.reduceOption((h1, h2) => HintInfo( - broadcast = h1.broadcast || h2.broadcast)) + hints.reduceOption((h1, h2) => h1.merge(h2)) } private def collectHints(plan: LogicalPlan): Seq[HintInfo] = { @@ -50,7 +49,7 @@ object EliminateResolvedHint extends Rule[LogicalPlan] { case u: UnaryNode => collectHints(u.child) // TODO revisit this logic: // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the broadcast hint should be propagated here + // their left argument, so the join strategy hint should be propagated here case i: Intersect => collectHints(i.left) case e: Except => collectHints(e.left) case _ => Seq.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index b2ba725e9d44..ddd577aa137c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -66,17 +66,76 @@ object JoinHint { /** * The hint attributes to be applied on a specific node. 
* - * @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join - * strategy and the node with this hint is preferred to be the build side. + * @param strategy The preferred join strategy. */ -case class HintInfo(broadcast: Boolean = false) { +case class HintInfo(strategy: Option[JoinStrategyHint] = None) { + + /** + * Combine two [[HintInfo]]s into one [[HintInfo]], in which the new strategy will be the + * strategy in this [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]]. + */ + def merge(other: HintInfo): HintInfo = { + HintInfo(strategy = this.strategy.orElse(other.strategy)) + } override def toString: String = { val hints = scala.collection.mutable.ArrayBuffer.empty[String] - if (broadcast) { - hints += "broadcast" + if (strategy.isDefined) { + hints += s"strategy=${strategy.get}" } if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")") } } + +sealed abstract class JoinStrategyHint { + + def displayName: String + def hintAliases: Set[String] + + override def toString: String = displayName +} + +/** + * The enumeration of join strategy hints. + * + * The hinted strategy will be used for the join with which it is associated if doable. In case + * of contradicting strategy hints specified for each side of the join, hints are prioritized as + * BROADCAST over SHUFFLE_MERGE over SHUFFLE_HASH over SHUFFLE_REPLICATE_NL. 
+ */ +object JoinStrategyHint { + + val strategies: Set[JoinStrategyHint] = Set( + BROADCAST, + SHUFFLE_MERGE, + SHUFFLE_HASH, + SHUFFLE_REPLICATE_NL) +} + +case object BROADCAST extends JoinStrategyHint { + override def displayName: String = "broadcast-hash" + override def hintAliases: Set[String] = Set( + "BROADCAST", + "BROADCASTJOIN", + "MAPJOIN") +} + +case object SHUFFLE_MERGE extends JoinStrategyHint { + override def displayName: String = "shuffle-merge" + override def hintAliases: Set[String] = Set( + "SHUFFLE_MERGE", + "MERGE", + "MERGEJOIN") +} + +case object SHUFFLE_HASH extends JoinStrategyHint { + override def displayName: String = "shuffle-hash" + override def hintAliases: Set[String] = Set( + "SHUFFLE_HASH") +} + +case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint { + override def displayName: String = "shuffle-replicate-nested-loop" + override def hintAliases: Set[String] = Set( + "SHUFFLE_REPLICATE_NL") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 563e8adf87ed..0429e7c0681b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -37,17 +37,17 @@ class ResolveHintsSuite extends AnalysisTest { test("case-sensitive or insensitive parameters") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), 
table("TaBlE")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = true) checkAnalysis( @@ -59,28 +59,29 @@ class ResolveHintsSuite extends AnalysisTest { test("multiple broadcast hint aliases") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), - ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE), + Join(ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), + ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))), + Inner, None, JoinHint.NONE), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), - ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))), - ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze, + ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))), + ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. 
@@ -105,7 +106,7 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true)) + ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(strategy = Some(BROADCAST))) .select('a).analyze, caseSensitive = false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bf38189c2fb5..d3eb67b008bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -211,18 +211,54 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { private def canBroadcastByHints( joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = { - val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast) - val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast) + val buildLeft = + canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(BROADCAST)) + val buildRight = + canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(BROADCAST)) buildLeft || buildRight } private def broadcastSideByHints( joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = { - val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast) - val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast) + val buildLeft = + canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(BROADCAST)) + val buildRight = + canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(BROADCAST)) broadcastSide(buildLeft, buildRight, left, right) } + private def canShuffleHashByHints( + joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = { + val 
buildLeft = + canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) + val buildRight = + canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) + buildLeft || buildRight + } + + private def shuffleHashSideByHints( + joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = { + val buildLeft = + canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) + val buildRight = + canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) + broadcastSide(buildLeft, buildRight, left, right) + } + + private def canShuffleMergeByHints( + leftKeys: Seq[Expression], hint: JoinHint): Boolean = { + val isOrderable = RowOrdering.isOrderable(leftKeys) + val hasMergeHint = + (hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) + || hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))) + isOrderable && hasMergeHint + } + + private def shuffleReplicateNLByHints(hint: JoinHint): Boolean = { + (hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) + || hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))) + } + private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) : Boolean = { val buildLeft = canBuildLeft(joinType) && canBroadcast(left) @@ -239,18 +275,47 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // --- BroadcastHashJoin -------------------------------------------------------------------- + // --- Hints specified, choose join strategy based on hints. 
-------------------------------- - // broadcast hints were specified + // broadcast hints specified with equi-join keys, use broadcast-hash case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) - if canBroadcastByHints(joinType, left, right, hint) => + if canBroadcastByHints(joinType, left, right, hint) => val buildSide = broadcastSideByHints(joinType, left, right, hint) Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) - // broadcast hints were not specified, so need to infer it from size and configuration. + // broadcast hints specified with no equi-join keys, use broadcast-nested-loop + case j @ logical.Join(left, right, joinType, condition, hint) + if canBroadcastByHints(joinType, left, right, hint) => + val buildSide = broadcastSideByHints(joinType, left, right, hint) + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + + // shuffle-merge hints specified + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) + if canShuffleMergeByHints(leftKeys, hint) => + Seq(joins.SortMergeJoinExec( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right))) + + // shuffle-hash hints specified + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) + if canShuffleHashByHints(joinType, left, right, hint) => + val buildSide = shuffleHashSideByHints(joinType, left, right, hint) + Seq(joins.ShuffledHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + + // shuffle-replicate-nl hints specified + case logical.Join(left, right, _: InnerLike, condition, hint) + if shuffleReplicateNLByHints(hint) => + Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)) + + + // --- No hints specified, choose join strategy based on size and configuration. 
------------ + + // --- BroadcastHashJoin -------------------------------------------------------------------- + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if canBroadcastBySizes(joinType, left, right) => + if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) @@ -281,12 +346,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ // Pick BroadcastNestedLoopJoin if one side could be broadcast - case j @ logical.Join(left, right, joinType, condition, hint) - if canBroadcastByHints(joinType, left, right, hint) => - val buildSide = broadcastSideByHints(joinType, left, right, hint) - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil - case j @ logical.Join(left, right, joinType, condition, _) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) @@ -299,7 +358,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Join(left, right, joinType, condition, hint) => val buildSide = broadcastSide( - hint.leftHint.exists(_.broadcast), hint.rightHint.exists(_.broadcast), left, right) + hint.leftHint.exists(_.strategy.contains(BROADCAST)), + hint.rightHint.exists(_.strategy.contains(BROADCAST)), left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f99186cabc26..2d7a1a214490 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf @@ -1045,7 +1045,7 @@ object functions { */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc) + ResolvedHint(df.logicalPlan, HintInfo(strategy = Some(BROADCAST))))(df.exprEnc) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 2141be4d680f..0f7a25908aa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join} import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -940,7 +940,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext case Join(_, _, 
_, _, hint) => hint } assert(hint.size == 1) - assert(hint(0).leftHint.get.broadcast) + assert(hint(0).leftHint.get.strategy.contains(BROADCAST)) assert(hint(0).rightHint.isEmpty) // Clean-up diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 67f0f1a6fd23..76b503ca85fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -43,14 +44,14 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { verifyJoinHint( df.hint("broadcast").join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( df.join(df.hint("broadcast"), "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) } @@ -59,18 +60,18 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"), 'a1 === 'c1), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3, 'a1 === 'c1), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) } @@ -89,13 +90,13 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { |) 
b on a.a1 = b.b1 """.stripMargin), JoinHint( - Some(HintInfo(broadcast = true)), - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST))), + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) } @@ -112,9 +113,9 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( @@ -122,25 +123,25 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint.NONE :: JoinHint( - Some(HintInfo(broadcast = true)), - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST))), + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) verifyJoinHint( sql("select /*+ broadcast(b, c)*/ * from a, c, b " + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( None, - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) verifyJoinHint( df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") .join(df3, 'b1 === 'c1 && 'a1 < 10), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint.NONE :: Nil ) @@ -151,7 +152,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { .join(df, 'b1 === 'id), JoinHint.NONE :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint.NONE :: Nil ) @@ -164,7 +165,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { 
verifyJoinHint( df.hint("broadcast").except(dfSub).join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint.NONE :: Nil ) @@ -172,7 +173,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { df.join(df.hint("broadcast").intersect(dfSub), "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint.NONE :: Nil ) } @@ -181,14 +182,14 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { verifyJoinHint( df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) } @@ -196,7 +197,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { verifyJoinHint( df.hint("broadcast").hint("broadcast").filter('id > 2).join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) } @@ -209,12 +210,230 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { join.join(broadcasted, "id").join(broadcasted, "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil ) } } + + def equiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"): String = + hints.map("/*+ " + _ + " */").mkString( + "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key = t2.key") + + def nonEquiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"): String = + hints.map("/*+ " + _ + " */").mkString( + "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON 
t1.key > t2.key") + + private def assertBroadcastHashJoin(df: DataFrame, buildSide: BuildSide): Unit = { + val executedPlan = df.queryExecution.executedPlan + val broadcastHashJoins = executedPlan.collect { + case b: BroadcastHashJoinExec => b + } + assert(broadcastHashJoins.size == 1) + assert(broadcastHashJoins.head.buildSide == buildSide) + } + + private def assertBroadcastNLJoin(df: DataFrame, buildSide: BuildSide): Unit = { + val executedPlan = df.queryExecution.executedPlan + val broadcastNLJoins = executedPlan.collect { + case b: BroadcastNestedLoopJoinExec => b + } + assert(broadcastNLJoins.size == 1) + assert(broadcastNLJoins.head.buildSide == buildSide) + } + + private def assertShuffleHashJoin(df: DataFrame, buildSide: BuildSide): Unit = { + val executedPlan = df.queryExecution.executedPlan + val shuffleHashJoins = executedPlan.collect { + case s: ShuffledHashJoinExec => s + } + assert(shuffleHashJoins.size == 1) + assert(shuffleHashJoins.head.buildSide == buildSide) + } + + private def assertShuffleMergeJoin(df: DataFrame): Unit = { + val executedPlan = df.queryExecution.executedPlan + val shuffleMergeJoins = executedPlan.collect { + case s: SortMergeJoinExec => s + } + assert(shuffleMergeJoins.size == 1) + } + + private def assertShuffleReplicateNLJoin(df: DataFrame): Unit = { + val executedPlan = df.queryExecution.executedPlan + val shuffleReplicateNLJoins = executedPlan.collect { + case c: CartesianProductExec => c + } + assert(shuffleReplicateNLJoins.size == 1) + } + + test("join strategy hint - broadcast") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> 
"-1") { + // Broadcast hint specified on one side + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil)), BuildLeft) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("BROADCAST(t2)" :: Nil)), BuildRight) + + // Determine build side based on the join type and child relation sizes + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil)), BuildLeft) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "right")), BuildLeft) + + // Use broadcast-hash join if hinted "broadcast" and equi-join + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)), BuildRight) + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t2)" :: Nil)), BuildLeft) + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil)), + BuildLeft) + + // Use broadcast-nl join if hinted "broadcast" and non-equi-join + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "BROADCAST(t1)" :: Nil)), BuildLeft) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "BROADCAST(t2)" :: Nil)), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: "BROADCAST(t2)" :: Nil)), + BuildRight) + + // Broadcast hint specified but not doable + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil, "left"))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: Nil, "right"))) + } + } + } + + test("join strategy hint - shuffle-merge") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) 
{ + // Shuffle-merge hint specified on one side + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("SHUFFLE_MERGE(t1)" :: Nil))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("MERGEJOIN(t2)" :: Nil))) + + // Shuffle-merge hint specified on both sides + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("MERGE(t1, t2)" :: Nil))) + + // Shuffle-hash hint prioritized over shuffle-hash hint and shuffle-replicate-nl hint + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "MERGE(t1)" :: Nil, "left"))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("MERGE(t2)" :: "SHUFFLE_HASH(t1)" :: Nil, "right"))) + + // Broadcast hint prioritized over shuffle-merge hint, but broadcast hint is not applicable + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t2)" :: Nil, "left"))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "MERGE(t1)" :: Nil, "right"))) + + // Shuffle-merge hint specified but not doable + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildRight) + } + } + } + + test("join strategy hint - shuffle-hash") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) { + // Shuffle-hash hint specified on one side + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: Nil)), BuildRight) + + // Determine build side based on the join type and child relation sizes + assertShuffleHashJoin( + 
sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil)), BuildLeft) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "left")), BuildRight) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "right")), BuildLeft) + + // Shuffle-hash hint prioritized over shuffle-replicate-nl hint + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)), + BuildLeft) + + // Broadcast hint prioritized over shuffle-hash hint, but broadcast hint is not applicable + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_HASH(t2)" :: Nil, "left")), + BuildRight) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil, "right")), + BuildLeft) + + // Shuffle-hash hint specified but not doable + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil, "left")), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft) + } + } + } + + test("join strategy hint - shuffle-replicate-nl") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) { + // Shuffle-replicate-nl hint specified on one side + assertShuffleReplicateNLJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: Nil))) + assertShuffleReplicateNLJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: Nil))) + + // Shuffle-replicate-nl hint specified on both sides + assertShuffleReplicateNLJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil))) + + // Shuffle-merge hint prioritized over shuffle-replicate-nl hint, but shuffle-merge hint + // is not applicable + assertShuffleReplicateNLJoin( + sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: 
"SHUFFLE_REPLICATE_NL(t2)" :: Nil))) + + // Shuffle-hash hint prioritized over shuffle-replicate-nl hint, but shuffle-hash hint is + // not applicable + assertShuffleReplicateNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "SHUFFLE_REPLICATE_NL(t1)" :: Nil))) + + // Shuffle-replicate-nl hint specified but not doable + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "left")), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "right")), BuildLeft) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index f238148e61c3..05c583c80e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements @@ -216,10 +217,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast) + assert(plan1.asInstanceOf[Join].hint.leftHint.get.strategy.contains(BROADCAST)) assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty) assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty) - 
assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast) + assert(plan2.asInstanceOf[Join].hint.rightHint.get.strategy.contains(BROADCAST)) assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty) assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty) } From e77c9f34a857bcc516fc3b27b75f9304564156a0 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 25 Mar 2019 21:27:31 -0500 Subject: [PATCH 02/12] address review comments --- docs/sql-performance-tuning.md | 24 ++-- .../sql/catalyst/plans/logical/hints.scala | 13 ++ .../spark/sql/execution/SparkStrategies.scala | 125 ++++++++++++------ 3 files changed, 110 insertions(+), 52 deletions(-) diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index 7c7c4a815545..51b5d0194ae6 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -92,14 +92,22 @@ that these options will be deprecated in future release as more optimizations ar -## Broadcast Hint for SQL Queries - -The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, -even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. -When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) -support BHJ. When the broadcast nested loop join is selected, we still respect the hint. +## Join Strategy Hints for SQL Queries + +The join strategy hints, namely `BROADCAST`, `MERGE`, `SHUFFLE_HASH` and `SHUFFLE_REPLICATE_NL`, +instruct Spark to use the hinted strategy on each specified relation when joining them with another +relation. 
For example, when the `BROADCAST` hint is used on table 't1', broadcast join (either +broadcast hash join or broadcast nested loop join depending on whether there is any equi-join key) +with 't1' as the build side will be prioritized by Spark even if the size of table 't1' suggested +by the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. + +When different join strategy hints are specified on both sides of a join, Spark prioritizes the +`BROADCAST` hint over the `MERGE` hint over the `SHUFFLE_HASH` hint over the `SHUFFLE_REPLICATE_NL` +hint. When both sides are specified with the `BROADCAST` hint or the `SHUFFLE_HASH` hint, Spark will +pick the build side based on the join type and the sizes of the relations. + +Note that there is no guarantee that Spark will choose the join strategy specified in the hint since +a specific strategy may not support all join types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index ddd577aa137c..007ba483d411 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -112,6 +112,10 @@ object JoinStrategyHint { SHUFFLE_REPLICATE_NL) } +/** + * The hint for broadcast hash join or broadcast nested loop join, depending on the availability of + * equi-join keys. + */ case object BROADCAST extends JoinStrategyHint { override def displayName: String = "broadcast-hash" override def hintAliases: Set[String] = Set( @@ -120,6 +124,9 @@ case object BROADCAST extends JoinStrategyHint { "MAPJOIN") } +/** + * The hint for shuffle sort merge join. + */ case object SHUFFLE_MERGE extends JoinStrategyHint { override def displayName: String = "shuffle-merge" override def hintAliases: Set[String] = Set( @@ -128,12 +135,18 @@ case object SHUFFLE_MERGE extends JoinStrategyHint { "MERGEJOIN") } +/** + * The hint for shuffle hash join. + */ case object SHUFFLE_HASH extends JoinStrategyHint { override def displayName: String = "shuffle-hash" override def hintAliases: Set[String] = Set( "SHUFFLE_HASH") } +/** + * The hint for shuffle-and-replicate nested loop join, a.k.a. cartesian product join. 
+ */ case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint { override def displayName: String = "shuffle-replicate-nested-loop" override def hintAliases: Set[String] = Set( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d3eb67b008bb..b8bc0fd9a1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,61 +90,98 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Select the proper physical plan for join based on joining keys and size of logical plan. - * - * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the - * predicates can be evaluated by matching join keys. If found, join implementations are chosen - * with the following precedence: + * Select the proper physical plan for join based on join strategy hints, the availability of + * equi-join keys and the sizes of joining relations. Below are the existing join strategies, + * their characteristics and their limitations. * * - Broadcast hash join (BHJ): - * BHJ is not supported for full outer join. For right outer join, we only can broadcast the - * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin, - * we only can broadcast the right side. For inner like join, we can broadcast both sides. - * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is - * small. However, broadcasting tables is a network-intensive operation. It could cause OOM - * or perform worse than the other join algorithms, especially when the build/broadcast side - * is big. + * Only supported for equi-joins, while the join keys do not need to be sortable. + * Supported for all join types except full outer joins. 
+ * BHJ usually performs faster than the other join algorithms when the broadcast side is + * small. However, broadcasting tables is a network-intensive operation and it could cause + * OOM or perform badly in some cases, especially when the build/broadcast side is big. + * + * - Shuffle hash join: + * Only supported for equi-joins, while the join keys do not need to be sortable. + * Supported for all join types except full outer joins. + * + * - Shuffle sort merge join (SMJ): + * Only supported for equi-joins and the join keys have to be sortable. + * Supported for all join types. + * + * - Broadcast nested loop join (BNLJ): + * Supports both equi-joins and non-equi-joins. + * Supports all the join types, but the implementation is optimized for: + * 1) broadcasting the left side in a right outer join; + * 2) broadcasting the right side in a left outer, left semi, left anti or existence join; + * 3) broadcasting either side in an inner-like join. + * + * - Shuffle-and-replicate nested loop join (a.k.a. cartesian product join): + * Supports both equi-joins and non-equi-joins. + * Supports only inner like joins. + * + * First, look at applicable join strategies hints: + * + * 1. Use broadcast hash join if: + * a) it is an equi-join; and + * b) either side has `BROADCAST` hint and is buildable, i.e., not being the null-generating + * side of an outer join (e.g., either side of an inner-like join, left side of a right + * outer join, or right side of a left outer, left semi, left anti or existence join). + * If both sides satisfy b), choose the smaller side (based on stats) to broadcast. + * + * 2. Use broadcast nested loop join if: + * either side has `BROADCAST` hint and is buildable. + * If both sides satisfy, choose the smaller side to broadcast. + * ** Note that hitting this branch implies this is a non-equi-join. + * + * 3. 
Use shuffle sort merge join if: + * a) it is an equi-join; and + * b) the equi-join keys are sortable; and + * c) either side has `MERGE` hint. + * + * 4. Use shuffle hash join if: + * a) it is an equi-join; and + * b) either side has `SHUFFLE_HASH` hint and is buildable. + * If both sides satisfy b), choose the smaller side to build. * - * For the supported cases, users can specify the broadcast hint (e.g. the user applied the - * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and - * which join side is broadcast. + * 5. Use shuffle-and-replicate nested loop join if: + * a) it is an inner-like join; and + * b) either side has `SHUFFLE_REPLICATE_NL` hint. * - * 1) Broadcast the join side with the broadcast hint, even if the size is larger than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type - * is inner like join), the side with a smaller estimated physical size will be broadcast. - * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side - * whose estimated physical size is smaller than the threshold. If both sides are below the - * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used. + * Second, use the [[ExtractEquiJoinKeys]] pattern to find equi-join keys from the join predicates + * and choose an equi-join algorithm based on the following precedence: * - * - Shuffle hash join: if the average size of a single partition is small enough to build a hash - * table. + * 1. Use broadcast hash join if: + * a) it is an equi-join; and + * b) either side is buildable and is of a size equal to or smaller than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. + * If both sides satisfy b), choose the smaller side to broadcast. * - * - Sort merge: if the matching join keys are sortable. + * 2. 
Use shuffle hash join if: + * a) it is an equi-join; and + * b) either of the following holds: + * - the equi-join keys are not sortable; + * - [[SQLConf.PREFER_SORTMERGEJOIN]] is `false` and either side is buildable and its + * average partition size is equal to or smaller than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] and that side is much smaller than the other + * side. * - * If there is no joining keys, Join implementations are chosen with the following precedence: - * - BroadcastNestedLoopJoin (BNLJ): - * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios: - * For right outer join, the left side is broadcast. For left outer, left semi, left anti - * and the internal join type ExistenceJoin, the right side is broadcast. For inner like - * joins, either side is broadcast. + * 3. Use shuffle sort merge join if: + * a) it is an equi-join; and + * b) the equi-join keys are sortable. * - * Like BHJ, users still can specify the broadcast hint and session-based - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast. + * Last, given there is no equi-join keys, choose the join algorithm for non-equi-joins based on + * the following precedence: * - * 1) Broadcast the join side with the broadcast hint, even if the size is larger than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for - * inner-like join), the side with a smaller estimated physical size will be broadcast. - * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side - * whose estimated physical size is smaller than the threshold. If both sides are below the - * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used. + * 1. Use broadcast nested loop join if: + * either side is buildable and is of a size equal to or smaller than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. + * If both sides satisfy, choose the smaller side to broadcast. 
* - * - CartesianProduct: for inner like join, CartesianProduct is the fallback option. + * 2. Use shuffle-and-replicate nested loop join if: + * it is an inner-like join. * - * - BroadcastNestedLoopJoin (BNLJ): - * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast - * side with the broadcast hint. If neither side has a hint, we broadcast the side with - * the smaller estimated physical size. + * 3. Use broadcast nested loop join as the fallback option. Choose the smaller side to broadcast. */ object JoinSelection extends Strategy with PredicateHelper { From f198dfb9838dbc50eef0f72fa8bc681194be3f5b Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 3 Apr 2019 15:15:54 -0500 Subject: [PATCH 03/12] add more tests --- .../sql/catalyst/analysis/ResolveHints.scala | 10 ++- .../optimizer/EliminateResolvedHint.scala | 2 +- .../org/apache/spark/sql/JoinHintSuite.scala | 71 ++++++++++++++++++- 3 files changed, 78 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 4231775682ed..75dbdbc70323 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -69,6 +69,13 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { + case ResolvedHint(u: UnresolvedRelation, _) + if relations.exists(resolver(_, u.tableIdentifier.table)) => + ResolvedHint(u, createHintInfo(hintName)) + case ResolvedHint(r: SubqueryAlias, _) + if relations.exists(resolver(_, r.alias)) => + ResolvedHint(r, createHintInfo(hintName)) + case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) => ResolvedHint(plan, createHintInfo(hintName)) case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) => @@ 
-76,8 +83,7 @@ object ResolveHints { case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. - // For an existing strategy hint, there is no point going down (if we do, we either - // won't change the structure, or will introduce another strategy hint that is useless. + // For an existing strategy hint, there is no chance for a match from this point down. // The rest (view, with, subquery) indicates different scopes that we shouldn't traverse // down. Note that technically when this rule is executed, we haven't completed view // resolution yet and as a result the view part should be deadcode. I'm leaving it here diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala index df50bd7840c6..78383c25e0a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala @@ -45,7 +45,7 @@ object EliminateResolvedHint extends Rule[LogicalPlan] { private def collectHints(plan: LogicalPlan): Seq[HintInfo] = { plan match { - case h: ResolvedHint => collectHints(h.child) :+ h.hints + case h: ResolvedHint => h.hints +: collectHints(h.child) case u: UnaryNode => collectHints(u.child) // TODO revisit this logic: // except and intersect are semi/anti-joins which won't return more data then diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 76b503ca85fb..af6f91bbbdae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -191,6 +191,67 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { None, Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) 
+ verifyJoinHint( + df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id"), + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_HASH))), + None) :: Nil + ) + verifyJoinHint( + df.join(df.hint("broadcast").limit(2).hint("merge"), "id") + .hint("shuffle_hash") + .hint("shuffle_replicate_nl") + .join(df, "id"), + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))), + None) :: + JoinHint( + None, + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil + ) + } + + test("hint merge - SQL") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + df3.createOrReplaceTempView("c") + verifyJoinHint( + sql("select /*+ merge(a, c) broadcast(a, b)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil + ) + verifyJoinHint( + sql("select /*+ shuffle_hash(a, b) merge(b, d) broadcast(b)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_HASH))), + Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil + ) + verifyJoinHint( + sql( + """ + |select /*+ broadcast(a, c) merge(a, d)*/ * from a + |join ( + | select /*+ shuffle_hash(c) shuffle_replicate_nl(b, c)*/ * from b + | join c on b.b1 = c.c1 + |) as d + |on a.a2 = d.b2 + """.stripMargin), + JoinHint( + Some(HintInfo(strategy = Some(BROADCAST))), + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))), + Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil + ) + } } test("nested hint") { @@ -200,6 +261,12 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) + verifyJoinHint( + df.hint("shuffle_hash").hint("broadcast").hint("merge").filter('id > 
2).join(df, "id"), + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), + None) :: Nil + ) } test("hints prevent cost-based join reorder") { @@ -298,7 +365,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { assertBroadcastHashJoin( sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)), BuildRight) assertBroadcastHashJoin( - sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t2)" :: Nil)), BuildLeft) + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t1, t2)" :: Nil)), BuildLeft) assertBroadcastHashJoin( sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil)), BuildLeft) @@ -337,7 +404,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { assertShuffleMergeJoin( sql(equiJoinQueryWithHint("MERGE(t1, t2)" :: Nil))) - // Shuffle-hash hint prioritized over shuffle-hash hint and shuffle-replicate-nl hint + // Shuffle-merge hint prioritized over shuffle-hash hint and shuffle-replicate-nl hint assertShuffleMergeJoin( sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "MERGE(t1)" :: Nil, "left"))) assertShuffleMergeJoin( From 4a13ffe664df7b85385da6dedd4660f2a5eb2f32 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Sun, 7 Apr 2019 12:10:29 -0500 Subject: [PATCH 04/12] address review comments --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b8bc0fd9a1e3..fe94d03bc025 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -120,7 +120,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. 
* - * First, look at applicable join strategies hints: + * First, look at applicable join strategies hints based on the following precedence: * * 1. Use broadcast hash join if: * a) it is an equi-join; and @@ -322,7 +322,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) // broadcast hints specified with no equi-join keys, use broadcast-nested-loop - case j @ logical.Join(left, right, joinType, condition, hint) + case logical.Join(left, right, joinType, condition, hint) if canBroadcastByHints(joinType, left, right, hint) => val buildSide = broadcastSideByHints(joinType, left, right, hint) Seq(joins.BroadcastNestedLoopJoinExec( @@ -383,7 +383,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ // Pick BroadcastNestedLoopJoin if one side could be broadcast - case j @ logical.Join(left, right, joinType, condition, _) + case logical.Join(left, right, joinType, condition, _) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) joins.BroadcastNestedLoopJoinExec( From 407c63fa86b0bf0d2f0872aa1357d247e55cde42 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 9 Apr 2019 01:31:30 +0800 Subject: [PATCH 05/12] refactor --- .../spark/sql/execution/SparkStrategies.scala | 379 ++++++++---------- 1 file changed, 166 insertions(+), 213 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index fe94d03bc025..3313ec289163 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -119,69 +119,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * - 
Shuffle-and-replicate nested loop join (a.k.a. cartesian product join): * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. - * - * First, look at applicable join strategies hints based on the following precedence: - * - * 1. Use broadcast hash join if: - * a) it is an equi-join; and - * b) either side has `BROADCAST` hint and is buildable, i.e., not being the null-generating - * side of an outer join (e.g., either side of an inner-like join, left side of a right - * outer join, or right side of a left outer, left semi, left anti or existence join). - * If both sides satisfy b), choose the smaller side (based on stats) to broadcast. - * - * 2. Use broadcast nested loop join if: - * either side has `BROADCAST` hint and is buildable. - * If both sides satisfy, choose the smaller side to broadcast. - * ** Note that hitting this branch implies this is a non-equi-join. - * - * 3. Use shuffle sort merge join if: - * a) it is an equi-join; and - * b) the equi-join keys are sortable; and - * c) either side has `MERGE` hint. - * - * 4. Use shuffle hash join if: - * a) it is an equi-join; and - * b) either side has `SHUFFLE_HASH` hint and is buildable. - * If both sides satisfy b), choose the smaller side to build. - * - * 5. Use shuffle-and-replicate nested loop join if: - * a) it is an inner-like join; and - * b) either side has `SHUFFLE_REPLICATE_NL` hint. - * - * Second, use the [[ExtractEquiJoinKeys]] pattern to find equi-join keys from the join predicates - * and choose an equi-join algorithm based on the following precedence: - * - * 1. Use broadcast hash join if: - * a) it is an equi-join; and - * b) either side is buildable and is of a size equal to or smaller than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. - * If both sides satisfy b), choose the smaller side to broadcast. - * - * 2. 
Use shuffle hash join if: - * a) it is an equi-join; and - * b) either of the following holds: - * - the equi-join keys are not sortable; - * - [[SQLConf.PREFER_SORTMERGEJOIN]] is `false` and either side is buildable and its - * average partition size is equal to or smaller than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] and that side is much smaller than the other - * side. - * - * 3. Use shuffle sort merge join if: - * a) it is an equi-join; and - * b) the equi-join keys are sortable. - * - * Last, given there is no equi-join keys, choose the join algorithm for non-equi-joins based on - * the following precedence: - * - * 1. Use broadcast nested loop join if: - * either side is buildable and is of a size equal to or smaller than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. - * If both sides satisfy, choose the smaller side to broadcast. - * - * 2. Use shuffle-and-replicate nested loop join if: - * it is an inner-like join. - * - * 3. Use broadcast nested loop join as the fallback option. Choose the smaller side to broadcast. */ object JoinSelection extends Strategy with PredicateHelper { @@ -223,187 +160,203 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => false } - private def broadcastSide( - canBuildLeft: Boolean, - canBuildRight: Boolean, + private def getBuildSide( + wantToBuildLeft: Boolean, + wantToBuildRight: Boolean, left: LogicalPlan, - right: LogicalPlan): BuildSide = { - - def smallerSide = - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - - if (canBuildRight && canBuildLeft) { - // Broadcast smaller side base on its estimated physical size - // if both sides have broadcast hint - smallerSide - } else if (canBuildRight) { - BuildRight - } else if (canBuildLeft) { - BuildLeft + right: LogicalPlan): Option[BuildSide] = { + if (wantToBuildLeft && wantToBuildRight) { + // returns the smaller side base on its estimated physical size, if we want to build the + // both sides. 
+ if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { + Some(BuildRight) + } else { + Some(BuildLeft) + } + } else if (wantToBuildLeft) { + Some(BuildLeft) + } else if (wantToBuildRight) { + Some(BuildRight) } else { - // for the last default broadcast nested loop join - smallerSide + None } } - private def canBroadcastByHints( - joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = { - val buildLeft = - canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(BROADCAST)) - val buildRight = - canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(BROADCAST)) - buildLeft || buildRight + private def hintToSortMergeJoin(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) + } + + private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) } private def broadcastSideByHints( - joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = { - val buildLeft = + joinType: JoinType, + left: LogicalPlan, + right: LogicalPlan, + hint: JoinHint): Option[BuildSide] = { + val wantToBuildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(BROADCAST)) - val buildRight = + val wantToBuildRight = canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(BROADCAST)) - broadcastSide(buildLeft, buildRight, left, right) - } - - private def canShuffleHashByHints( - joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = { - val buildLeft = - canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) - val buildRight = - canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) - buildLeft || buildRight + getBuildSide(wantToBuildLeft, wantToBuildRight, 
left, right) } private def shuffleHashSideByHints( - joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = { - val buildLeft = + joinType: JoinType, + left: LogicalPlan, + right: LogicalPlan, + hint: JoinHint): Option[BuildSide] = { + val wantToBuildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) - val buildRight = + val wantToBuildRight = canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) - broadcastSide(buildLeft, buildRight, left, right) + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) } - private def canShuffleMergeByHints( - leftKeys: Seq[Expression], hint: JoinHint): Boolean = { - val isOrderable = RowOrdering.isOrderable(leftKeys) - val hasMergeHint = - (hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) - || hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))) - isOrderable && hasMergeHint + private def broadcastSideBySizes( + joinType: JoinType, + left: LogicalPlan, + right: LogicalPlan): Option[BuildSide] = { + val wantToBuildLeft = canBuildLeft(joinType) && canBroadcast(left) + val wantToBuildRight = canBuildRight(joinType) && canBroadcast(right) + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) } - private def shuffleReplicateNLByHints(hint: JoinHint): Boolean = { - (hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) - || hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))) + private def shuffleHashSideBySizes( + joinType: JoinType, + left: LogicalPlan, + right: LogicalPlan): Option[BuildSide] = { + val wantToBuildLeft = + canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) + val wantToBuildRight = + canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) } - private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) - : Boolean = { - 
val buildLeft = canBuildLeft(joinType) && canBroadcast(left) - val buildRight = canBuildRight(joinType) && canBroadcast(right) - buildLeft || buildRight + private def createCartesianProduct( + left: LogicalPlan, + right: LogicalPlan, + condition: Option[Expression]) = { + Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)) } - private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) - : BuildSide = { - val buildLeft = canBuildLeft(joinType) && canBroadcast(left) - val buildRight = canBuildRight(joinType) && canBroadcast(right) - broadcastSide(buildLeft, buildRight, left, right) + private def createFinalBroadcastNLJoin( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) = { + val smallerSide = + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + // This join could be very slow or OOM + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), smallerSide, joinType, condition)) } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // --- Hints specified, choose join strategy based on hints. 
-------------------------------- - - // broadcast hints specified with equi-join keys, use broadcast-hash - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) - if canBroadcastByHints(joinType, left, right, hint) => - val buildSide = broadcastSideByHints(joinType, left, right, hint) - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) - - // broadcast hints specified with no equi-join keys, use broadcast-nested-loop - case logical.Join(left, right, joinType, condition, hint) - if canBroadcastByHints(joinType, left, right, hint) => - val buildSide = broadcastSideByHints(joinType, left, right, hint) - Seq(joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition)) - - // shuffle-merge hints specified - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) - if canShuffleMergeByHints(leftKeys, hint) => - Seq(joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right))) - - // shuffle-hash hints specified - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) - if canShuffleHashByHints(joinType, left, right, hint) => - val buildSide = shuffleHashSideByHints(joinType, left, right, hint) - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) - - // shuffle-replicate-nl hints specified - case logical.Join(left, right, _: InnerLike, condition, hint) - if shuffleReplicateNLByHints(hint) => - Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)) - - - // --- No hints specified, choose join strategy based on size and configuration. 
------------ - - // --- BroadcastHashJoin -------------------------------------------------------------------- - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if canBroadcastBySizes(joinType, left, right) => - val buildSide = broadcastSideBySizes(joinType, left, right) - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) - - // --- ShuffledHashJoin --------------------------------------------------------------------- - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) - && muchSmaller(right, left) || - !RowOrdering.isOrderable(leftKeys) => - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) - && muchSmaller(left, right) || - !RowOrdering.isOrderable(leftKeys) => - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) - - // --- SortMergeJoin ------------------------------------------------------------ - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if RowOrdering.isOrderable(leftKeys) => - joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - - // --- Without joining keys ------------------------------------------------------------ - - // Pick BroadcastNestedLoopJoin if one side could be broadcast - case logical.Join(left, right, joinType, condition, _) - if canBroadcastBySizes(joinType, left, right) => - val buildSide = broadcastSideBySizes(joinType, left, right) - joins.BroadcastNestedLoopJoinExec( - planLater(left), 
planLater(right), buildSide, joinType, condition) :: Nil - - // Pick CartesianProduct for InnerJoin - case logical.Join(left, right, _: InnerLike, condition, _) => - joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil + // If it is an equi-join, we first look at the join hints w.r.t. the following order: + // 1. broadcast hint: pick broadcast hash join if join type is not full outer. If both sides + // have the broadcast hints, choose the smaller side (based on stats) to broadcast. + // 2. sort merge hint: pick sort merge join if join keys are sortable. + // 3. shuffle hash hint: We pick shuffle hash join if join type is not full outer. If both + // sides have the shuffle hash hints, choose the smaller side (based on stats) as the + // build side. + // 4. shuffle replicate NL hint: pick cartesian product if join type is inner like. + // + // If there is no hint or the hints are not applicable, we follow these rules one by one: + // 1. Pick broadcast hash join if one side is small enough to broadcast, and the join type + // is not full outer. If both sides are small, choose the smaller side (based on stats) + // to broadcast. + // 2. Pick shuffle hash join if one side is small enough to build local hash map, and is + // much smaller than the other side, and `conf.preferSortMergeJoin` is false. + // 3. Pick sort merge join if the join keys are sortable. + // 4. Pick cartesian product if join type is inner like. + // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have + // other choice. 
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => + def createBroadcastHashJoin(buildSide: BuildSide) = { + Seq(joins.BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + } + + def createSortMergeJoin() = { + Seq(joins.SortMergeJoinExec( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right))) + } + + def createShuffleHashJoin(buildSide: BuildSide) = { + Seq(joins.ShuffledHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + } + + broadcastSideByHints(joinType, left, right, hint).map(createBroadcastHashJoin).getOrElse { + if (RowOrdering.isOrderable(leftKeys) && hintToSortMergeJoin(hint)) { + createSortMergeJoin() + } else { + shuffleHashSideByHints(joinType, left, right, hint).map { side => + createShuffleHashJoin(side) + }.getOrElse { + if (joinType.isInstanceOf[InnerLike] && hintToShuffleReplicateNL(hint)) { + createCartesianProduct(left, right, condition) + } else { + createJoinWithoutHint() + } + } + } + } + def createJoinWithoutHint() = { + broadcastSideBySizes(joinType, left, right).map(createBroadcastHashJoin).getOrElse { + val shuffleHashBuildSide = shuffleHashSideBySizes(joinType, left, right) + if (!conf.preferSortMergeJoin && shuffleHashBuildSide.isDefined) { + createShuffleHashJoin(shuffleHashBuildSide.get) + } else if (RowOrdering.isOrderable(leftKeys)) { + createSortMergeJoin() + } else if (joinType.isInstanceOf[InnerLike]) { + createCartesianProduct(left, right, condition) + } else { + createFinalBroadcastNLJoin(left, right, joinType, condition) + } + } + } + + // If it is not an equi-join, we first look at the join hints w.r.t. the following order: + // 1. broadcast hint: pick broadcast nested loop join. If both sides have the broadcast + // hints, choose the smaller side (based on stats) to broadcast. + // 2. 
shuffle replicate NL hint: pick cartesian product if join type is inner like. + // + // If there is no hint or the hints are not applicable, we follow these rules one by one: + // 1. Pick broadcast nested loop join if one side is small enough to broadcast. If both + // sides are small, choose the smaller side (based on stats) to broadcast. + // 2. Pick cartesian product if join type is inner like. + // 3. Pick broadcast nested loop join as the final solution. It may OOM but we don't have + // other choice. case logical.Join(left, right, joinType, condition, hint) => - val buildSide = broadcastSide( - hint.leftHint.exists(_.strategy.contains(BROADCAST)), - hint.rightHint.exists(_.strategy.contains(BROADCAST)), left, right) - // This join could be very slow or OOM - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + def createBroadcastNLJoin(buildSide: BuildSide) = { + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } - // --- Cases where this strategy does not apply --------------------------------------------- + broadcastSideByHints(joinType, left, right, hint).map(createBroadcastNLJoin).getOrElse { + if (joinType.isInstanceOf[InnerLike] && hintToShuffleReplicateNL(hint)) { + createCartesianProduct(left, right, condition) + } else { + createJoinWithoutHint() + } + } - case _ => Nil + def createJoinWithoutHint() = { + broadcastSideBySizes(joinType, left, right).map(createBroadcastNLJoin).getOrElse { + if (joinType.isInstanceOf[InnerLike]) { + createCartesianProduct(left, right, condition) + } else { + createFinalBroadcastNLJoin(left, right, joinType, condition) + } + } + } } } From 6bd7f5692d5c15774d4b5d7657ffa6b2d2ca197f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 9 Apr 2019 10:29:16 +0800 Subject: [PATCH 06/12] fix --- .../sql/catalyst/plans/logical/hints.scala | 7 +-- .../spark/sql/execution/SparkStrategies.scala | 49 
++++++++++--------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 007ba483d411..5b6a6fcde190 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -79,12 +79,7 @@ case class HintInfo(strategy: Option[JoinStrategyHint] = None) { } override def toString: String = { - val hints = scala.collection.mutable.ArrayBuffer.empty[String] - if (strategy.isDefined) { - hints += s"strategy=${strategy.get}" - } - - if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")") + strategy.map(s => s"strategy=$s").getOrElse("none") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3313ec289163..cf968132be10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -292,6 +292,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) } + def createJoinWithoutHint() = { + broadcastSideBySizes(joinType, left, right).map(createBroadcastHashJoin).getOrElse { + val shuffleHashBuildSide = shuffleHashSideBySizes(joinType, left, right) + if (!conf.preferSortMergeJoin && shuffleHashBuildSide.isDefined) { + createShuffleHashJoin(shuffleHashBuildSide.get) + } else if (RowOrdering.isOrderable(leftKeys)) { + createSortMergeJoin() + } else if (joinType.isInstanceOf[InnerLike]) { + createCartesianProduct(left, right, condition) + } else { + createFinalBroadcastNLJoin(left, right, joinType, condition) + } + } + } + 
broadcastSideByHints(joinType, left, right, hint).map(createBroadcastHashJoin).getOrElse { if (RowOrdering.isOrderable(leftKeys) && hintToSortMergeJoin(hint)) { createSortMergeJoin() @@ -308,21 +323,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - def createJoinWithoutHint() = { - broadcastSideBySizes(joinType, left, right).map(createBroadcastHashJoin).getOrElse { - val shuffleHashBuildSide = shuffleHashSideBySizes(joinType, left, right) - if (!conf.preferSortMergeJoin && shuffleHashBuildSide.isDefined) { - createShuffleHashJoin(shuffleHashBuildSide.get) - } else if (RowOrdering.isOrderable(leftKeys)) { - createSortMergeJoin() - } else if (joinType.isInstanceOf[InnerLike]) { - createCartesianProduct(left, right, condition) - } else { - createFinalBroadcastNLJoin(left, right, joinType, condition) - } - } - } - // If it is not an equi-join, we first look at the join hints w.r.t. the following order: // 1. broadcast hint: pick broadcast nested loop join. If both sides have the broadcast // hints, choose the smaller side (based on stats) to broadcast. 
@@ -340,14 +340,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { planLater(left), planLater(right), buildSide, joinType, condition)) } - broadcastSideByHints(joinType, left, right, hint).map(createBroadcastNLJoin).getOrElse { - if (joinType.isInstanceOf[InnerLike] && hintToShuffleReplicateNL(hint)) { - createCartesianProduct(left, right, condition) - } else { - createJoinWithoutHint() - } - } - def createJoinWithoutHint() = { broadcastSideBySizes(joinType, left, right).map(createBroadcastNLJoin).getOrElse { if (joinType.isInstanceOf[InnerLike]) { @@ -357,6 +349,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } } + + broadcastSideByHints(joinType, left, right, hint).map(createBroadcastNLJoin).getOrElse { + if (joinType.isInstanceOf[InnerLike] && hintToShuffleReplicateNL(hint)) { + createCartesianProduct(left, right, condition) + } else { + createJoinWithoutHint() + } + } + + // --- Cases where this strategy does not apply --------------------------------------------- + case _ => Nil } } From c535d3680687049e4d3d92d8f11d17df6e2e1944 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 8 Apr 2019 23:05:12 -0500 Subject: [PATCH 07/12] add hint event handling --- .../sql/catalyst/analysis/ResolveHints.scala | 37 ++++++-- .../optimizer/EliminateResolvedHint.scala | 53 +++++++++--- .../sql/catalyst/plans/logical/hints.scala | 35 ++++---- .../org/apache/spark/sql/JoinHintSuite.scala | 86 +++++++++++++++---- 4 files changed, 160 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 75dbdbc70323..9440a3f806b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis 
import java.util.Locale +import scala.collection.mutable + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ @@ -62,23 +64,27 @@ object ResolveHints { private def applyJoinStrategyHint( plan: LogicalPlan, - relations: Set[String], + relations: mutable.HashSet[String], hintName: String): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case ResolvedHint(u: UnresolvedRelation, _) + case ResolvedHint(u: UnresolvedRelation, hint) if relations.exists(resolver(_, u.tableIdentifier.table)) => - ResolvedHint(u, createHintInfo(hintName)) - case ResolvedHint(r: SubqueryAlias, _) + relations.remove(u.tableIdentifier.table) + ResolvedHint(u, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo)) + case ResolvedHint(r: SubqueryAlias, hint) if relations.exists(resolver(_, r.alias)) => - ResolvedHint(r, createHintInfo(hintName)) + relations.remove(r.alias) + ResolvedHint(r, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo)) case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) => + relations.remove(u.tableIdentifier.table) ResolvedHint(plan, createHintInfo(hintName)) case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) => + relations.remove(r.alias) ResolvedHint(plan, createHintInfo(hintName)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => @@ -103,6 +109,10 @@ object ResolveHints { } } + private def handleOverriddenHintInfo(hint: HintInfo): Unit = { + logWarning(s"Join hint $hint is overridden by another hint and will not take effect.") + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { @@ -110,12 +120,21 @@ object ResolveHints { 
ResolvedHint(h.child, createHintInfo(h.name)) } else { // Otherwise, find within the subtree query plans to apply the hint. - applyJoinStrategyHint(h.child, h.parameters.map { + val relationNames = h.parameters.map { case tableName: String => tableName case tableId: UnresolvedAttribute => tableId.name case unsupported => throw new AnalysisException("Join strategy hint parameter " + s"should be an identifier or string but was $unsupported (${unsupported.getClass}") - }.toSet, h.name) + } + val relationNameSet = new mutable.HashSet[String] + relationNames.foreach(relationNameSet.add) + + val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name) + relationNameSet.foreach { n => + logWarning(s"Count not find relation '$n' for join strategy hint " + + s"'${h.name}${relationNames.mkString("(", ", ", ")")}'.") + } + applied } } } @@ -152,7 +171,9 @@ object ResolveHints { */ object RemoveAllHints extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case h: UnresolvedHint => h.child + case h: UnresolvedHint => + logWarning(s"Unrecognized hint: ${h.name}${h.parameters.mkString("(", ", ", ")")}") + h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala index 78383c25e0a1..5586690520c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala @@ -30,29 +30,58 @@ object EliminateResolvedHint extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { val pulledUp = plan transformUp { case j: Join => - val leftHint = mergeHints(collectHints(j.left)) - val rightHint = mergeHints(collectHints(j.right)) - j.copy(hint = JoinHint(leftHint, rightHint)) + val (newLeft, leftHints) = 
extractHintsFromPlan(j.left) + val (newRight, rightHints) = extractHintsFromPlan(j.right) + val newJoinHint = JoinHint(mergeHints(leftHints), mergeHints(rightHints)) + j.copy(left = newLeft, right = newRight, hint = newJoinHint) } pulledUp.transformUp { - case h: ResolvedHint => h.child + case h: ResolvedHint => + handleInvalidHintInfo(h.hints) + h.child } } + /** + * Combine a list of [[HintInfo]]s into one [[HintInfo]]. + */ private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = { - hints.reduceOption((h1, h2) => h1.merge(h2)) + hints.reduceOption((h1, h2) => h1.merge(h2, handleOverriddenHintInfo)) } - private def collectHints(plan: LogicalPlan): Seq[HintInfo] = { + /** + * Extract all hints from the plan, returning a list of extracted hints and the transformed plan + * with [[ResolvedHint]] nodes removed. The returned hint list comes in top-down order. + * Note that hints can only be extracted from under certain nodes. Those that cannot be extracted + * in this method will be cleaned up later by this rule, and may emit warnings depending on the + * configurations. 
+ */ + private def extractHintsFromPlan(plan: LogicalPlan): (LogicalPlan, Seq[HintInfo]) = { plan match { - case h: ResolvedHint => h.hints +: collectHints(h.child) - case u: UnaryNode => collectHints(u.child) + case h: ResolvedHint => + val (plan, hints) = extractHintsFromPlan(h.child) + (plan, h.hints +: hints) + case u: UnaryNode => + val (plan, hints) = extractHintsFromPlan(u.child) + (u.withNewChildren(Seq(plan)), hints) // TODO revisit this logic: // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the join strategy hint should be propagated here - case i: Intersect => collectHints(i.left) - case e: Except => collectHints(e.left) - case _ => Seq.empty + // their left argument, so the broadcast hint should be propagated here + case i: Intersect => + val (plan, hints) = extractHintsFromPlan(i.left) + (i.copy(left = plan), hints) + case e: Except => + val (plan, hints) = extractHintsFromPlan(e.left) + (e.copy(left = plan), hints) + case p: LogicalPlan => (p, Seq.empty) } } + + private def handleInvalidHintInfo(hint: HintInfo): Unit = { + logWarning(s"A join hint $hint is specified but it is not part of a join relation.") + } + + private def handleOverriddenHintInfo(hint: HintInfo): Unit = { + logWarning(s"Join hint $hint is overridden by another hint and will not take effect.") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 007ba483d411..00a39b3133a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -71,21 +71,26 @@ object JoinHint { case class HintInfo(strategy: Option[JoinStrategyHint] = None) { /** - * Combine two [[HintInfo]]s into one [[HintInfo]], in which the new strategy will the strategy - * in this [[HintInfo]] if 
defined, otherwise the strategy in the other [[HintInfo]]. + * Combine this [[HintInfo]] with another [[HintInfo]] and return the new [[HintInfo]]. + * @param other the other [[HintInfo]] + * @param hintOverriddenCallback a callback to notify if any [[HintInfo]] has been overridden + * in this merge. + * + * Currently, for join strategy hints, the new [[HintInfo]] will contain the strategy in this + * [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]]. The + * `hintOverriddenCallback` will be called if this [[HintInfo]] and the other [[HintInfo]] + * both have a strategy defined but the join strategies are different. */ - def merge(other: HintInfo): HintInfo = { + def merge(other: HintInfo, hintOverriddenCallback: HintInfo => Unit): HintInfo = { + if (this.strategy.isDefined && + other.strategy.isDefined && + this.strategy.get != other.strategy.get) { + hintOverriddenCallback(other) + } HintInfo(strategy = this.strategy.orElse(other.strategy)) } - override def toString: String = { - val hints = scala.collection.mutable.ArrayBuffer.empty[String] - if (strategy.isDefined) { - hints += s"strategy=${strategy.get}" - } - - if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")") - } + override def toString: String = if (strategy.isDefined) s"(strategy=${strategy.get})" else "none" } sealed abstract class JoinStrategyHint { @@ -117,7 +122,7 @@ object JoinStrategyHint { * equi-join keys. */ case object BROADCAST extends JoinStrategyHint { - override def displayName: String = "broadcast-hash" + override def displayName: String = "broadcast" override def hintAliases: Set[String] = Set( "BROADCAST", "BROADCASTJOIN", @@ -128,7 +133,7 @@ case object BROADCAST extends JoinStrategyHint { * The hint for shuffle sort merge join. 
*/ case object SHUFFLE_MERGE extends JoinStrategyHint { - override def displayName: String = "shuffle-merge" + override def displayName: String = "merge" override def hintAliases: Set[String] = Set( "SHUFFLE_MERGE", "MERGE", @@ -139,7 +144,7 @@ case object SHUFFLE_MERGE extends JoinStrategyHint { * The hint for shuffle hash join. */ case object SHUFFLE_HASH extends JoinStrategyHint { - override def displayName: String = "shuffle-hash" + override def displayName: String = "shuffle_hash" override def hintAliases: Set[String] = Set( "SHUFFLE_HASH") } @@ -148,7 +153,7 @@ case object SHUFFLE_HASH extends JoinStrategyHint { * The hint for shuffle-and-replicate nested loop join, a.k.a. cartesian product join. */ case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint { - override def displayName: String = "shuffle-replicate-nested-loop" + override def displayName: String = "shuffle_replicate_nl" override def hintAliases: Set[String] = Set( "SHUFFLE_REPLICATE_NL") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index af6f91bbbdae..9c2dc0c62b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql +import scala.collection.mutable.ArrayBuffer + +import org.apache.log4j.{AppenderSkeleton, Level} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.joins._ @@ -31,6 +36,41 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { lazy val df2 = df.selectExpr("id as b1", "id as b2") lazy val df3 = df.selectExpr("id as c1", "id as c2") + class MockAppender extends AppenderSkeleton { + val loggingEvents = new ArrayBuffer[LoggingEvent]() + + override def append(loggingEvent: LoggingEvent): Unit = 
loggingEvents.append(loggingEvent) + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + def msgNoHintRelationFound(relation: String, hint: String): String = + s"Count not find relation '$relation' for join strategy hint '$hint'." + + def msgNoJoinForJoinHint(strategy: String): String = + s"A join hint (strategy=$strategy) is specified but it is not part of a join relation." + + def msgJoinHintOverridden(strategy: String): String = + s"Join hint (strategy=$strategy) is overridden by another hint and will not take effect." + + def verifyJoinHintWithWarnings( + df: => DataFrame, + expectedHints: Seq[JoinHint], + warnings: Seq[String]): Unit = { + val logAppender = new MockAppender() + withLogAppender(logAppender) { + verifyJoinHint(df, expectedHints) + } + val warningMessages = logAppender.loggingEvents + .filter(_.getLevel == Level.WARN) + .map(_.getRenderedMessage) + .filter(_.contains("hint")) + assert(warningMessages.size == warnings.size) + warnings.foreach { w => + assert(warningMessages.contains(w)) + } + } + def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = { val optimized = df.queryExecution.optimizedPlan val joinHints = optimized collect { @@ -179,25 +219,29 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { } test("hint merge") { - verifyJoinHint( + verifyJoinHintWithWarnings( df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"), JoinHint( Some(HintInfo(strategy = Some(BROADCAST))), - None) :: Nil + None) :: Nil, + Nil ) - verifyJoinHint( + verifyJoinHintWithWarnings( df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"), JoinHint( None, - Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil, + Nil ) - verifyJoinHint( - df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id"), + verifyJoinHintWithWarnings( + df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id").hint("broadcast"), 
JoinHint( Some(HintInfo(strategy = Some(SHUFFLE_HASH))), - None) :: Nil + None) :: Nil, + msgJoinHintOverridden("merge") :: + msgNoJoinForJoinHint("broadcast") :: Nil ) - verifyJoinHint( + verifyJoinHintWithWarnings( df.join(df.hint("broadcast").limit(2).hint("merge"), "id") .hint("shuffle_hash") .hint("shuffle_replicate_nl") @@ -207,7 +251,9 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { None) :: JoinHint( None, - Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil, + msgJoinHintOverridden("broadcast") :: + msgJoinHintOverridden("shuffle_hash") :: Nil ) } @@ -216,25 +262,30 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { df1.createOrReplaceTempView("a") df2.createOrReplaceTempView("b") df3.createOrReplaceTempView("c") - verifyJoinHint( - sql("select /*+ merge(a, c) broadcast(a, b)*/ * from a, b, c " + + verifyJoinHintWithWarnings( + sql("select /*+ shuffle_hash merge(a, c) broadcast(a, b)*/ * from a, b, c " + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: JoinHint( Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), - Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil, + msgNoJoinForJoinHint("shuffle_hash") :: + msgJoinHintOverridden("broadcast") :: Nil ) - verifyJoinHint( + verifyJoinHintWithWarnings( sql("select /*+ shuffle_hash(a, b) merge(b, d) broadcast(b)*/ * from a, b, c " + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint.NONE :: JoinHint( Some(HintInfo(strategy = Some(SHUFFLE_HASH))), - Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil + Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil, + msgNoHintRelationFound("d", "merge(b, d)") :: + msgJoinHintOverridden("broadcast") :: + msgJoinHintOverridden("merge") :: Nil ) - verifyJoinHint( + verifyJoinHintWithWarnings( sql( """ |select /*+ broadcast(a, c) merge(a, d)*/ * from a @@ -249,7 +300,10 @@ class 
JoinHintSuite extends PlanTest with SharedSQLContext { Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: JoinHint( Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))), - Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil + Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil, + msgNoHintRelationFound("c", "broadcast(a, c)") :: + msgJoinHintOverridden("merge") :: + msgJoinHintOverridden("shuffle_replicate_nl") :: Nil ) } } From 7342fbdadc55f15c0a2529299715e639aa3d84f7 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Tue, 9 Apr 2019 16:42:34 -0500 Subject: [PATCH 08/12] add more tests --- .../catalyst/analysis/ResolveHintsSuite.scala | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 0429e7c0681b..474e58a335e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.ArrayBuffer + +import org.apache.log4j.{AppenderSkeleton, Level} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal @@ -27,6 +32,14 @@ import org.apache.spark.sql.catalyst.plans.logical._ class ResolveHintsSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ + class MockAppender extends AppenderSkeleton { + val loggingEvents = new ArrayBuffer[LoggingEvent]() + + override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent) + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + test("invalid hints should 
be ignored") { checkAnalysis( UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), @@ -156,4 +169,17 @@ class ResolveHintsSuite extends AnalysisTest { UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), Seq(errMsgRepa)) } + + test("log warnings for invalid hints") { + val logAppender = new MockAppender() + withLogAppender(logAppender) { + checkAnalysis( + UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE")), + testRelation, + caseSensitive = false) + } + assert(logAppender.loggingEvents.exists( + e => e.getLevel == Level.WARN && + e.getRenderedMessage.contains("Unrecognized hint: unknown_hint"))) + } } From 091299718fbe39fd863aa823e776f31466dadd36 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Apr 2019 14:01:03 +0800 Subject: [PATCH 09/12] fix tests --- .../spark/sql/execution/SparkStrategies.scala | 200 +++++++++--------- .../execution/joins/BroadcastJoinSuite.scala | 10 +- 2 files changed, 102 insertions(+), 108 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cf968132be10..42252ee78c84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -182,68 +182,37 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - private def hintToSortMergeJoin(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) + private def hintToBroadcastLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(BROADCAST)) } - private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) + private def 
hintToBroadcastRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(BROADCAST)) } - private def broadcastSideByHints( - joinType: JoinType, - left: LogicalPlan, - right: LogicalPlan, - hint: JoinHint): Option[BuildSide] = { - val wantToBuildLeft = - canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(BROADCAST)) - val wantToBuildRight = - canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(BROADCAST)) - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) - } - - private def shuffleHashSideByHints( - joinType: JoinType, - left: LogicalPlan, - right: LogicalPlan, - hint: JoinHint): Option[BuildSide] = { - val wantToBuildLeft = - canBuildLeft(joinType) && hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) - val wantToBuildRight = - canBuildRight(joinType) && hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) + private def hintToShuffleHashLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) } - private def broadcastSideBySizes( - joinType: JoinType, - left: LogicalPlan, - right: LogicalPlan): Option[BuildSide] = { - val wantToBuildLeft = canBuildLeft(joinType) && canBroadcast(left) - val wantToBuildRight = canBuildRight(joinType) && canBroadcast(right) - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) + private def hintToShuffleHashRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) } - private def shuffleHashSideBySizes( - joinType: JoinType, - left: LogicalPlan, - right: LogicalPlan): Option[BuildSide] = { - val wantToBuildLeft = - canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) - val wantToBuildRight = - canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right) + private def hintToSortMergeJoin(hint: JoinHint): 
Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) } - private def createCartesianProduct( - left: LogicalPlan, - right: LogicalPlan, - condition: Option[Expression]) = { - Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)) + private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) } - private def createFinalBroadcastNLJoin( + /** + * Create BroadcastNestedLoopJoinExec forcibly as the final solution, when no other join + * strategy is applicable. + */ + private def createForcibleBroadcastNLJoin( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, @@ -258,10 +227,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // If it is an equi-join, we first look at the join hints w.r.t. the following order: - // 1. broadcast hint: pick broadcast hash join if join type is not full outer. If both sides + // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides // have the broadcast hints, choose the smaller side (based on stats) to broadcast. // 2. sort merge hint: pick sort merge join if join keys are sortable. - // 3. shuffle hash hint: We pick shuffle hash join if join type is not full outer. If both + // 3. shuffle hash hint: We pick shuffle hash join if the join type is supported. If both // sides have the shuffle hash hints, choose the smaller side (based on stats) as the // build side. // 4. shuffle replicate NL hint: pick cartesian product if join type is inner like. @@ -270,59 +239,82 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 1. Pick broadcast hash join if one side is small enough to broadcast, and the join type // is not full outer. 
If both sides are small, choose the smaller side (based on stats) // to broadcast. - // 2. Pick sort merge join if the join keys are sortable. - // 3. Pick shuffle hash join if one side is small enough to build local hash map, and is - // much smaller than the other side. + // 2. Pick shuffle hash join if one side is small enough to build local hash map, and is + // much smaller than the other side, and `spark.sql.join.preferSortMergeJoin` is false. + // 3. Pick sort merge join if the join keys are sortable. // 4. Pick cartesian product if join type is inner like. // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have // other choice. case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => - def createBroadcastHashJoin(buildSide: BuildSide) = { - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = { + val wantToBuildLeft = canBuildLeft(joinType) && buildLeft + val wantToBuildRight = canBuildRight(joinType) && buildRight + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => + Seq(joins.BroadcastHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) + } } - def createSortMergeJoin() = { - Seq(joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right))) + def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = { + val wantToBuildLeft = canBuildLeft(joinType) && buildLeft + val wantToBuildRight = canBuildRight(joinType) && buildRight + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) + } } - def createShuffleHashJoin(buildSide: BuildSide) = { - 
Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + def createSortMergeJoin() = { + if (RowOrdering.isOrderable(leftKeys)) { + Some(Seq(joins.SortMergeJoinExec( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)))) + } else { + None + } } - def createJoinWithoutHint() = { - broadcastSideBySizes(joinType, left, right).map(createBroadcastHashJoin).getOrElse { - val shuffleHashBuildSide = shuffleHashSideBySizes(joinType, left, right) - if (!conf.preferSortMergeJoin && shuffleHashBuildSide.isDefined) { - createShuffleHashJoin(shuffleHashBuildSide.get) - } else if (RowOrdering.isOrderable(leftKeys)) { - createSortMergeJoin() - } else if (joinType.isInstanceOf[InnerLike]) { - createCartesianProduct(left, right, condition) - } else { - createFinalBroadcastNLJoin(left, right, joinType, condition) - } + def createCartesianProduct() = { + if (joinType.isInstanceOf[InnerLike]) { + Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition))) + } else { + None } } - broadcastSideByHints(joinType, left, right, hint).map(createBroadcastHashJoin).getOrElse { - if (RowOrdering.isOrderable(leftKeys) && hintToSortMergeJoin(hint)) { - createSortMergeJoin() - } else { - shuffleHashSideByHints(joinType, left, right, hint).map { side => - createShuffleHashJoin(side) - }.getOrElse { - if (joinType.isInstanceOf[InnerLike] && hintToShuffleReplicateNL(hint)) { - createCartesianProduct(left, right, condition) + def createJoinWithoutHint() = { + createBroadcastHashJoin(canBroadcast(left), canBroadcast(right)) + .orElse { + if (!conf.preferSortMergeJoin) { + createShuffleHashJoin( + canBuildLocalHashMap(left) && muchSmaller(left, right), + canBuildLocalHashMap(right) && muchSmaller(right, left)) } else { - createJoinWithoutHint() + None } } - } + .orElse(createSortMergeJoin()) + .orElse(createCartesianProduct()) + .getOrElse(createForcibleBroadcastNLJoin(left, right, 
joinType, condition)) } + createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None } + .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint))) + .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } + .getOrElse(createJoinWithoutHint()) + // If it is not an equi-join, we first look at the join hints w.r.t. the following order: // 1. broadcast hint: pick broadcast nested loop join. If both sides have the broadcast // hints, choose the smaller side (based on stats) to broadcast. @@ -335,29 +327,31 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 3. Pick broadcast nested loop join as the final solution. It may OOM but we don't have // other choice. case logical.Join(left, right, joinType, condition, hint) => - def createBroadcastNLJoin(buildSide: BuildSide) = { - Seq(joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition)) - } - - def createJoinWithoutHint() = { - broadcastSideBySizes(joinType, left, right).map(createBroadcastNLJoin).getOrElse { - if (joinType.isInstanceOf[InnerLike]) { - createCartesianProduct(left, right, condition) - } else { - createFinalBroadcastNLJoin(left, right, joinType, condition) - } + def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = { + getBuildSide(buildLeft, buildRight, left, right).map { buildSide => + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) } } - broadcastSideByHints(joinType, left, right, hint).map(createBroadcastNLJoin).getOrElse { - if (joinType.isInstanceOf[InnerLike] && hintToShuffleReplicateNL(hint)) { - createCartesianProduct(left, right, condition) + def createCartesianProduct() = { + if (joinType.isInstanceOf[InnerLike]) { + Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition))) } else 
{ - createJoinWithoutHint() + None } } + def createJoinWithoutHint() = { + createBroadcastNLJoin(canBroadcast(left), canBroadcast(right)) + .orElse(createCartesianProduct()) + .getOrElse(createForcibleBroadcastNLJoin(left, right, joinType, condition)) + } + + createBroadcastNLJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } + .getOrElse(createJoinWithoutHint()) + // --- Cases where this strategy does not apply --------------------------------------------- case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 05c583c80e50..a413c2838a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -301,10 +301,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) - // FULL JOIN && t1Size < t2Size => BuildLeft + // FULL JOIN && no join key && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 FULL JOIN t2", bl, BuildLeft) - // LEFT JOIN => BuildRight - assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildRight) + // LEFT JOIN && no join key && t1Size < t2Size => BuildLeft + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildLeft) // RIGHT JOIN => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2", bl, BuildLeft) // INNER JOIN && broadcast(t1) => BuildLeft @@ -345,11 +345,11 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { 
assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, BuildRight) - assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight) + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildLeft) assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight) assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft) - assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildRight) } } } From c0b217c60e30829cd9eaba76f98ff43eb71756d7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Apr 2019 20:05:08 +0800 Subject: [PATCH 10/12] fix test --- .../src/test/scala/org/apache/spark/sql/JoinHintSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 9c2dc0c62b2f..a755bf8bbd90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -411,7 +411,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { assertBroadcastHashJoin( sql(equiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil)), BuildLeft) assertBroadcastNLJoin( - sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildRight) + sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildLeft) assertBroadcastNLJoin( sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "right")), BuildLeft) @@ -472,7 +472,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { // Shuffle-merge hint specified but not doable assertBroadcastNLJoin( - sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildRight) + sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildLeft) } } } From 4a48286f4a17422e22669083a4d9135fa98312ea 
Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 11 Apr 2019 13:53:07 +0800 Subject: [PATCH 11/12] fix behaviors --- .../spark/sql/execution/SparkStrategies.scala | 35 +++++++++++++++++-- .../org/apache/spark/sql/JoinHintSuite.scala | 4 +-- .../execution/joins/BroadcastJoinSuite.scala | 10 +++--- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 42252ee78c84..c66ba14bfecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -348,9 +348,38 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { .getOrElse(createForcibleBroadcastNLJoin(left, right, joinType, condition)) } - createBroadcastNLJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) - .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } - .getOrElse(createJoinWithoutHint()) + if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) { + createBroadcastNLJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } + .getOrElse(createJoinWithoutHint()) + } else { + val smallerSide = + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + val buildSide = if (canBuildLeft(joinType)) { + // For RIGHT JOIN, we may broadcast left side even if the hint asks us to broadcast + // the right side. This is for history reasons. + if (hintToBroadcastLeft(hint) || canBroadcast(left)) { + BuildLeft + } else if (hintToBroadcastRight(hint)) { + BuildRight + } else { + smallerSide + } + } else { + // For LEFT JOIN, we may broadcast right side even if the hint asks us to broadcast + // the left side. This is for history reasons. 
+ if (hintToBroadcastRight(hint) || canBroadcast(right)) { + BuildRight + } else if (hintToBroadcastLeft(hint)) { + BuildLeft + } else { + smallerSide + } + } + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } + // --- Cases where this strategy does not apply --------------------------------------------- case _ => Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index a755bf8bbd90..9c2dc0c62b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -411,7 +411,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { assertBroadcastHashJoin( sql(equiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil)), BuildLeft) assertBroadcastNLJoin( - sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildLeft) + sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildRight) assertBroadcastNLJoin( sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "right")), BuildLeft) @@ -472,7 +472,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { // Shuffle-merge hint specified but not doable assertBroadcastNLJoin( - sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildLeft) + sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildRight) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index a413c2838a35..05c583c80e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -301,10 +301,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { 
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) - // FULL JOIN && no join key && t1Size < t2Size => BuildLeft + // FULL JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 FULL JOIN t2", bl, BuildLeft) - // LEFT JOIN && no join key && t1Size < t2Size => BuildLeft - assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildLeft) + // LEFT JOIN => BuildRight + assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildRight) // RIGHT JOIN => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2", bl, BuildLeft) // INNER JOIN && broadcast(t1) => BuildLeft @@ -345,11 +345,11 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, BuildRight) - assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight) assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight) assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft) - assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft) } } } From a9634c4b0e0d0a194bc4bfc3441528e7c4f836ae Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 11 Apr 2019 14:26:42 +0800 Subject: [PATCH 12/12] code cleanup --- .../spark/sql/execution/SparkStrategies.scala | 56 ++++++++----------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c66ba14bfecd..efd05a3e2b3e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -168,11 +168,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if (wantToBuildLeft && wantToBuildRight) { // returns the smaller side base on its estimated physical size, if we want to build the // both sides. - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) { - Some(BuildRight) - } else { - Some(BuildLeft) - } + Some(getSmallerSide(left, right)) } else if (wantToBuildLeft) { Some(BuildLeft) } else if (wantToBuildRight) { @@ -182,6 +178,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + } + private def hintToBroadcastLeft(hint: JoinHint): Boolean = { hint.leftHint.exists(_.strategy.contains(BROADCAST)) } @@ -208,22 +208,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) } - /** - * Create BroadcastNestedLoopJoinExec forcibly as the final solution, when no other join - * strategy is applicable. - */ - private def createForcibleBroadcastNLJoin( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - condition: Option[Expression]) = { - val smallerSide = - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - // This join could be very slow or OOM - Seq(joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), smallerSide, joinType, condition)) - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // If it is an equi-join, we first look at the join hints w.r.t. 
the following order: @@ -237,7 +221,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // // If there is no hint or the hints are not applicable, we follow these rules one by one: // 1. Pick broadcast hash join if one side is small enough to broadcast, and the join type - // is not full outer. If both sides are small, choose the smaller side (based on stats) + // is supported. If both sides are small, choose the smaller side (based on stats) // to broadcast. // 2. Pick shuffle hash join if one side is small enough to build local hash map, and is // much smaller than the other side, and `spark.sql.join.preferSortMergeJoin` is false. @@ -306,7 +290,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } .orElse(createSortMergeJoin()) .orElse(createCartesianProduct()) - .getOrElse(createForcibleBroadcastNLJoin(left, right, joinType, condition)) + .getOrElse { + // This join could be very slow or OOM + val buildSide = getSmallerSide(left, right) + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } } createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) @@ -321,11 +310,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 2. shuffle replicate NL hint: pick cartesian product if join type is inner like. // // If there is no hint or the hints are not applicable, we follow these rules one by one: - // 1. Pick broadcast nested loop join if one side is small enough to broadcast. If both - // sides are small, choose the smaller side (based on stats) to broadcast. - // 2. Pick cartesian product if join type is inner like. - // 3. Pick broadcast nested loop join as the final solution. It may OOM but we don't have - // other choice. + // 1. Pick cartesian product if join type is inner like, and both sides are too big to + // broadcast. + // 2. Pick broadcast nested loop join. Pick the smaller side (based on stats) to broadcast. 
case logical.Join(left, right, joinType, condition, hint) => def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = { getBuildSide(buildLeft, buildRight, left, right).map { buildSide => @@ -343,9 +330,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastNLJoin(canBroadcast(left), canBroadcast(right)) - .orElse(createCartesianProduct()) - .getOrElse(createForcibleBroadcastNLJoin(left, right, joinType, condition)) + (if (!canBroadcast(left) && !canBroadcast(right)) createCartesianProduct() else None) + .getOrElse { + // This join could be very slow or OOM + val buildSide = getSmallerSide(left, right) + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } } if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) { @@ -353,8 +344,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } .getOrElse(createJoinWithoutHint()) } else { - val smallerSide = - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + val smallerSide = getSmallerSide(left, right) val buildSide = if (canBuildLeft(joinType)) { // For RIGHT JOIN, we may broadcast left side even if the hint asks us to broadcast // the right side. This is for history reasons.