diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index 68569744afcc..2a1edda84252 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -107,14 +107,22 @@ that these options will be deprecated in future release as more optimizations ar -## Broadcast Hint for SQL Queries - -The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, -even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. -When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) -support BHJ. When the broadcast nested loop join is selected, we still respect the hint. +## Join Strategy Hints for SQL Queries + +The join strategy hints, namely `BROADCAST`, `MERGE`, `SHUFFLE_HASH` and `SHUFFLE_REPLICATE_NL`, +instruct Spark to use the hinted strategy on each specified relation when joining them with another +relation. For example, when the `BROADCAST` hint is used on table 't1', broadcast join (either +broadcast hash join or broadcast nested loop join depending on whether there is any equi-join key) +with 't1' as the build side will be prioritized by Spark even if the size of table 't1' suggested +by the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. + +When different join strategy hints are specified on both sides of a join, Spark prioritizes the +`BROADCAST` hint over the `MERGE` hint over the `SHUFFLE_HASH` hint over the `SHUFFLE_REPLICATE_NL` +hint. When both sides are specified with the `BROADCAST` hint or the `SHUFFLE_HASH` hint, Spark will +pick the build side based on the join type and the sizes of the relations. + +Note that there is no guarantee that Spark will choose the join strategy specified in the hint since +a specific strategy may not support all join types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 01e40e64a3e8..02d83e7e8cb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -153,7 +153,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, - new ResolveHints.ResolveBroadcastHints(conf), + new ResolveHints.ResolveJoinStrategyHints(conf), ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index dbd4ed845e32..9440a3f806b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale +import scala.collection.mutable + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ @@ -28,45 +30,66 @@ import org.apache.spark.sql.internal.SQLConf /** - * Collection of rules related to hints. The only hint currently available is broadcast join hint. + * Collection of rules related to hints. The only hint currently available is join strategy hint. * * Note that this is separately into two rules because in the future we might introduce new hint - * rules that have different ordering requirements from broadcast. + * rules that have different ordering requirements from join strategies. */ object ResolveHints { /** - * For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of - * relation aliases can be specified in the hint. A broadcast hint plan node will be inserted - * on top of any relation (that is not aliased differently), subquery, or common table expression - * that match the specified name. + * The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a + * sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)", + * "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that + * is not aliased differently), subquery, or common table expression that match the specified + * name. * * The hint resolution works by recursively traversing down the query plan to find a relation or - * subquery that matches one of the specified broadcast aliases. The traversal does not go past - * beyond any existing broadcast hints, subquery aliases. + * subquery that matches one of the specified relation aliases. The traversal does not go past + * beyond any view reference, with clause or subquery alias. * * This rule must happen before common table expressions. */ - class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { - private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] { + private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases) def resolver: Resolver = conf.resolver - private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + private def createHintInfo(hintName: String): HintInfo = { + HintInfo(strategy = + JoinStrategyHint.strategies.find( + _.hintAliases.map( + _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT)))) + } + + private def applyJoinStrategyHint( + plan: LogicalPlan, + relations: mutable.HashSet[String], + hintName: String): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - ResolvedHint(plan, HintInfo(broadcast = true)) - case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - ResolvedHint(plan, HintInfo(broadcast = true)) + case ResolvedHint(u: UnresolvedRelation, hint) + if relations.exists(resolver(_, u.tableIdentifier.table)) => + relations.remove(u.tableIdentifier.table) + ResolvedHint(u, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo)) + case ResolvedHint(r: SubqueryAlias, hint) + if relations.exists(resolver(_, r.alias)) => + relations.remove(r.alias) + ResolvedHint(r, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo)) + + case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) => + relations.remove(u.tableIdentifier.table) + ResolvedHint(plan, createHintInfo(hintName)) + case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) => + relations.remove(r.alias) + ResolvedHint(plan, createHintInfo(hintName)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. - // For an existing broadcast hint, there is no point going down (if we do, we either - // won't change the structure, or will introduce another broadcast hint that is useless. + // For an existing strategy hint, there is no chance for a match from this point down. // The rest (view, with, subquery) indicates different scopes that we shouldn't traverse // down. Note that technically when this rule is executed, we haven't completed view // resolution yet and as a result the view part should be deadcode. I'm leaving it here @@ -80,25 +103,38 @@ object ResolveHints { } if ((plan fastEquals newNode) && recurse) { - newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast)) + newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName)) } else { newNode } } + private def handleOverriddenHintInfo(hint: HintInfo): Unit = { + logWarning(s"Join hint $hint is overridden by another hint and will not take effect.") + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { - // If there is no table alias specified, turn the entire subtree into a BroadcastHint. - ResolvedHint(h.child, HintInfo(broadcast = true)) + // If there is no table alias specified, apply the hint on the entire subtree. + ResolvedHint(h.child, createHintInfo(h.name)) } else { - // Otherwise, find within the subtree query plans that should be broadcasted. - applyBroadcastHint(h.child, h.parameters.map { + // Otherwise, find within the subtree query plans to apply the hint. + val relationNames = h.parameters.map { case tableName: String => tableName case tableId: UnresolvedAttribute => tableId.name - case unsupported => throw new AnalysisException("Broadcast hint parameter should be " + - s"an identifier or string but was $unsupported (${unsupported.getClass}") - }.toSet) + case unsupported => throw new AnalysisException("Join strategy hint parameter " + + s"should be an identifier or string but was $unsupported (${unsupported.getClass}") + } + val relationNameSet = new mutable.HashSet[String] + relationNames.foreach(relationNameSet.add) + + val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name) + relationNameSet.foreach { n => + logWarning(s"Count not find relation '$n' for join strategy hint " + + s"'${h.name}${relationNames.mkString("(", ", ", ")")}'.") + } + applied } } } @@ -135,7 +171,9 @@ object ResolveHints { */ object RemoveAllHints extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case h: UnresolvedHint => h.child + case h: UnresolvedHint => + logWarning(s"Unrecognized hint: ${h.name}${h.parameters.mkString("(", ", ", ")")}") + h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 60066371e3dc..2d646721f87a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -374,7 +374,7 @@ object CatalogTable { /** * This class of statistics is used in [[CatalogTable]] to interact with metastore. * We define this new class instead of directly using [[Statistics]] here because there are no - * concepts of attributes or broadcast hint in catalog. + * concepts of attributes in catalog. */ case class CatalogStatistics( sizeInBytes: BigInt, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala index a136f0493699..5586690520c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala @@ -30,30 +30,58 @@ object EliminateResolvedHint extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { val pulledUp = plan transformUp { case j: Join => - val leftHint = mergeHints(collectHints(j.left)) - val rightHint = mergeHints(collectHints(j.right)) - j.copy(hint = JoinHint(leftHint, rightHint)) + val (newLeft, leftHints) = extractHintsFromPlan(j.left) + val (newRight, rightHints) = extractHintsFromPlan(j.right) + val newJoinHint = JoinHint(mergeHints(leftHints), mergeHints(rightHints)) + j.copy(left = newLeft, right = newRight, hint = newJoinHint) } pulledUp.transformUp { - case h: ResolvedHint => h.child + case h: ResolvedHint => + handleInvalidHintInfo(h.hints) + h.child } } + /** + * Combine a list of [[HintInfo]]s into one [[HintInfo]]. + */ private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = { - hints.reduceOption((h1, h2) => HintInfo( - broadcast = h1.broadcast || h2.broadcast)) + hints.reduceOption((h1, h2) => h1.merge(h2, handleOverriddenHintInfo)) } - private def collectHints(plan: LogicalPlan): Seq[HintInfo] = { + /** + * Extract all hints from the plan, returning a list of extracted hints and the transformed plan + * with [[ResolvedHint]] nodes removed. The returned hint list comes in top-down order. + * Note that hints can only be extracted from under certain nodes. Those that cannot be extracted + * in this method will be cleaned up later by this rule, and may emit warnings depending on the + * configurations. + */ + private def extractHintsFromPlan(plan: LogicalPlan): (LogicalPlan, Seq[HintInfo]) = { plan match { - case h: ResolvedHint => collectHints(h.child) :+ h.hints - case u: UnaryNode => collectHints(u.child) + case h: ResolvedHint => + val (plan, hints) = extractHintsFromPlan(h.child) + (plan, h.hints +: hints) + case u: UnaryNode => + val (plan, hints) = extractHintsFromPlan(u.child) + (u.withNewChildren(Seq(plan)), hints) // TODO revisit this logic: // except and intersect are semi/anti-joins which won't return more data then // their left argument, so the broadcast hint should be propagated here - case i: Intersect => collectHints(i.left) - case e: Except => collectHints(e.left) - case _ => Seq.empty + case i: Intersect => + val (plan, hints) = extractHintsFromPlan(i.left) + (i.copy(left = plan), hints) + case e: Except => + val (plan, hints) = extractHintsFromPlan(e.left) + (e.copy(left = plan), hints) + case p: LogicalPlan => (p, Seq.empty) } } + + private def handleInvalidHintInfo(hint: HintInfo): Unit = { + logWarning(s"A join hint $hint is specified but it is not part of a join relation.") + } + + private def handleOverriddenHintInfo(hint: HintInfo): Unit = { + logWarning(s"Join hint $hint is overridden by another hint and will not take effect.") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index b2ba725e9d44..870dd87d8a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -66,17 +66,94 @@ object JoinHint { /** * The hint attributes to be applied on a specific node. * - * @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join - * strategy and the node with this hint is preferred to be the build side. + * @param strategy The preferred join strategy. */ -case class HintInfo(broadcast: Boolean = false) { +case class HintInfo(strategy: Option[JoinStrategyHint] = None) { - override def toString: String = { - val hints = scala.collection.mutable.ArrayBuffer.empty[String] - if (broadcast) { - hints += "broadcast" + /** + * Combine this [[HintInfo]] with another [[HintInfo]] and return the new [[HintInfo]]. + * @param other the other [[HintInfo]] + * @param hintOverriddenCallback a callback to notify if any [[HintInfo]] has been overridden + * in this merge. + * + * Currently, for join strategy hints, the new [[HintInfo]] will contain the strategy in this + * [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]]. The + * `hintOverriddenCallback` will be called if this [[HintInfo]] and the other [[HintInfo]] + * both have a strategy defined but the join strategies are different. + */ + def merge(other: HintInfo, hintOverriddenCallback: HintInfo => Unit): HintInfo = { + if (this.strategy.isDefined && + other.strategy.isDefined && + this.strategy.get != other.strategy.get) { + hintOverriddenCallback(other) } - - if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")") + HintInfo(strategy = this.strategy.orElse(other.strategy)) } + + override def toString: String = strategy.map(s => s"(strategy=$s)").getOrElse("none") +} + +sealed abstract class JoinStrategyHint { + + def displayName: String + def hintAliases: Set[String] + + override def toString: String = displayName +} + +/** + * The enumeration of join strategy hints. + * + * The hinted strategy will be used for the join with which it is associated if doable. In case + * of contradicting strategy hints specified for each side of the join, hints are prioritized as + * BROADCAST over SHUFFLE_MERGE over SHUFFLE_HASH over SHUFFLE_REPLICATE_NL. + */ +object JoinStrategyHint { + + val strategies: Set[JoinStrategyHint] = Set( + BROADCAST, + SHUFFLE_MERGE, + SHUFFLE_HASH, + SHUFFLE_REPLICATE_NL) +} + +/** + * The hint for broadcast hash join or broadcast nested loop join, depending on the availability of + * equi-join keys. + */ +case object BROADCAST extends JoinStrategyHint { + override def displayName: String = "broadcast" + override def hintAliases: Set[String] = Set( + "BROADCAST", + "BROADCASTJOIN", + "MAPJOIN") +} + +/** + * The hint for shuffle sort merge join. + */ +case object SHUFFLE_MERGE extends JoinStrategyHint { + override def displayName: String = "merge" + override def hintAliases: Set[String] = Set( + "SHUFFLE_MERGE", + "MERGE", + "MERGEJOIN") +} + +/** + * The hint for shuffle hash join. + */ +case object SHUFFLE_HASH extends JoinStrategyHint { + override def displayName: String = "shuffle_hash" + override def hintAliases: Set[String] = Set( + "SHUFFLE_HASH") +} + +/** + * The hint for shuffle-and-replicate nested loop join, a.k.a. cartesian product join. + */ +case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint { + override def displayName: String = "shuffle_replicate_nl" + override def hintAliases: Set[String] = Set( + "SHUFFLE_REPLICATE_NL") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 563e8adf87ed..474e58a335e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.ArrayBuffer + +import org.apache.log4j.{AppenderSkeleton, Level} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal @@ -27,6 +32,14 @@ import org.apache.spark.sql.catalyst.plans.logical._ class ResolveHintsSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.analysis.TestRelations._ + class MockAppender extends AppenderSkeleton { + val loggingEvents = new ArrayBuffer[LoggingEvent]() + + override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent) + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + test("invalid hints should be ignored") { checkAnalysis( UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), @@ -37,17 +50,17 @@ class ResolveHintsSuite extends AnalysisTest { test("case-sensitive or insensitive parameters") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = true) checkAnalysis( @@ -59,28 +72,29 @@ class ResolveHintsSuite extends AnalysisTest { test("multiple broadcast hint aliases") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), - ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE), + Join(ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), + ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))), + Inner, None, JoinHint.NONE), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), - ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))), - ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze, + ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))), + ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. @@ -105,7 +119,7 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true)) + ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(strategy = Some(BROADCAST))) .select('a).analyze, caseSensitive = false) } @@ -155,4 +169,17 @@ class ResolveHintsSuite extends AnalysisTest { UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), Seq(errMsgRepa)) } + + test("log warnings for invalid hints") { + val logAppender = new MockAppender() + withLogAppender(logAppender) { + checkAnalysis( + UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE")), + testRelation, + caseSensitive = false) + } + assert(logAppender.loggingEvents.exists( + e => e.getLevel == Level.WARN && + e.getRenderedMessage.contains("Unrecognized hint: unknown_hint"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index bf38189c2fb5..efd05a3e2b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,61 +90,35 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** - * Select the proper physical plan for join based on joining keys and size of logical plan. - * - * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the - * predicates can be evaluated by matching join keys. If found, join implementations are chosen - * with the following precedence: + * Select the proper physical plan for join based on join strategy hints, the availability of + * equi-join keys and the sizes of joining relations. Below are the existing join strategies, + * their characteristics and their limitations. * * - Broadcast hash join (BHJ): - * BHJ is not supported for full outer join. For right outer join, we only can broadcast the - * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin, - * we only can broadcast the right side. For inner like join, we can broadcast both sides. - * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is - * small. However, broadcasting tables is a network-intensive operation. It could cause OOM - * or perform worse than the other join algorithms, especially when the build/broadcast side - * is big. - * - * For the supported cases, users can specify the broadcast hint (e.g. the user applied the - * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and - * which join side is broadcast. - * - * 1) Broadcast the join side with the broadcast hint, even if the size is larger than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type - * is inner like join), the side with a smaller estimated physical size will be broadcast. - * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side - * whose estimated physical size is smaller than the threshold. If both sides are below the - * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used. - * - * - Shuffle hash join: if the average size of a single partition is small enough to build a hash - * table. - * - * - Sort merge: if the matching join keys are sortable. - * - * If there is no joining keys, Join implementations are chosen with the following precedence: - * - BroadcastNestedLoopJoin (BNLJ): - * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios: - * For right outer join, the left side is broadcast. For left outer, left semi, left anti - * and the internal join type ExistenceJoin, the right side is broadcast. For inner like - * joins, either side is broadcast. + * Only supported for equi-joins, while the join keys do not need to be sortable. + * Supported for all join types except full outer joins. + * BHJ usually performs faster than the other join algorithms when the broadcast side is + * small. However, broadcasting tables is a network-intensive operation and it could cause + * OOM or perform badly in some cases, especially when the build/broadcast side is big. * - * Like BHJ, users still can specify the broadcast hint and session-based - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast. + * - Shuffle hash join: + * Only supported for equi-joins, while the join keys do not need to be sortable. + * Supported for all join types except full outer joins. * - * 1) Broadcast the join side with the broadcast hint, even if the size is larger than - * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for - * inner-like join), the side with a smaller estimated physical size will be broadcast. - * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side - * whose estimated physical size is smaller than the threshold. If both sides are below the - * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used. + * - Shuffle sort merge join (SMJ): + * Only supported for equi-joins and the join keys have to be sortable. + * Supported for all join types. * - * - CartesianProduct: for inner like join, CartesianProduct is the fallback option. + * - Broadcast nested loop join (BNLJ): + * Supports both equi-joins and non-equi-joins. + * Supports all the join types, but the implementation is optimized for: + * 1) broadcasting the left side in a right outer join; + * 2) broadcasting the right side in a left outer, left semi, left anti or existence join; + * 3) broadcasting either side in an inner-like join. * - * - BroadcastNestedLoopJoin (BNLJ): - * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast - * side with the broadcast hint. If neither side has a hint, we broadcast the side with - * the smaller estimated physical size. + * - Shuffle-and-replicate nested loop join (a.k.a. cartesian product join): + * Supports both equi-joins and non-equi-joins. + * Supports only inner like joins. */ object JoinSelection extends Strategy with PredicateHelper { @@ -186,126 +160,218 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => false } - private def broadcastSide( - canBuildLeft: Boolean, - canBuildRight: Boolean, + private def getBuildSide( + wantToBuildLeft: Boolean, + wantToBuildRight: Boolean, left: LogicalPlan, - right: LogicalPlan): BuildSide = { - - def smallerSide = - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - - if (canBuildRight && canBuildLeft) { - // Broadcast smaller side base on its estimated physical size - // if both sides have broadcast hint - smallerSide - } else if (canBuildRight) { - BuildRight - } else if (canBuildLeft) { - BuildLeft + right: LogicalPlan): Option[BuildSide] = { + if (wantToBuildLeft && wantToBuildRight) { + // returns the smaller side base on its estimated physical size, if we want to build the + // both sides. + Some(getSmallerSide(left, right)) + } else if (wantToBuildLeft) { + Some(BuildLeft) + } else if (wantToBuildRight) { + Some(BuildRight) } else { - // for the last default broadcast nested loop join - smallerSide + None } } - private def canBroadcastByHints( - joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = { - val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast) - val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast) - buildLeft || buildRight + private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft } - private def broadcastSideByHints( - joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = { - val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast) - val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast) - broadcastSide(buildLeft, buildRight, left, right) + private def hintToBroadcastLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(BROADCAST)) } - private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) - : Boolean = { - val buildLeft = canBuildLeft(joinType) && canBroadcast(left) - val buildRight = canBuildRight(joinType) && canBroadcast(right) - buildLeft || buildRight + private def hintToBroadcastRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(BROADCAST)) } - private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) - : BuildSide = { - val buildLeft = canBuildLeft(joinType) && canBroadcast(left) - val buildRight = canBuildRight(joinType) && canBroadcast(right) - broadcastSide(buildLeft, buildRight, left, right) + private def hintToShuffleHashLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) + } + + private def hintToShuffleHashRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) + } + + private def hintToSortMergeJoin(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) + } + + private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // --- BroadcastHashJoin -------------------------------------------------------------------- - - // broadcast hints were specified - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) - if canBroadcastByHints(joinType, left, right, hint) => - val buildSide = broadcastSideByHints(joinType, left, right, hint) - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) - - // broadcast hints were not specified, so need to infer it from size and configuration. - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if canBroadcastBySizes(joinType, left, right) => - val buildSide = broadcastSideBySizes(joinType, left, right) - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) - - // --- ShuffledHashJoin --------------------------------------------------------------------- - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) - && muchSmaller(right, left) || - !RowOrdering.isOrderable(leftKeys) => - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) - && muchSmaller(left, right) || - !RowOrdering.isOrderable(leftKeys) => - Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) - - // --- SortMergeJoin ------------------------------------------------------------ - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) - if RowOrdering.isOrderable(leftKeys) => - joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - - // --- Without joining keys ------------------------------------------------------------ - - // Pick BroadcastNestedLoopJoin if one side could be broadcast - case j @ logical.Join(left, right, joinType, condition, hint) - if canBroadcastByHints(joinType, left, right, hint) => - val buildSide = broadcastSideByHints(joinType, left, right, hint) - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil - - case j @ logical.Join(left, right, joinType, condition, _) - if canBroadcastBySizes(joinType, left, right) => - val buildSide = broadcastSideBySizes(joinType, left, right) - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil - - // Pick CartesianProduct for InnerJoin - case logical.Join(left, right, _: InnerLike, condition, _) => - joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil + // If it is an equi-join, we first look at the join hints w.r.t. the following order: + // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides + // have the broadcast hints, choose the smaller side (based on stats) to broadcast. + // 2. sort merge hint: pick sort merge join if join keys are sortable. + // 3. shuffle hash hint: We pick shuffle hash join if the join type is supported. If both + // sides have the shuffle hash hints, choose the smaller side (based on stats) as the + // build side. + // 4. shuffle replicate NL hint: pick cartesian product if join type is inner like. + // + // If there is no hint or the hints are not applicable, we follow these rules one by one: + // 1. Pick broadcast hash join if one side is small enough to broadcast, and the join type + // is supported. If both sides are small, choose the smaller side (based on stats) + // to broadcast. + // 2. Pick shuffle hash join if one side is small enough to build local hash map, and is + // much smaller than the other side, and `spark.sql.join.preferSortMergeJoin` is false. + // 3. Pick sort merge join if the join keys are sortable. + // 4. Pick cartesian product if join type is inner like. + // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have + // other choice. + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => + def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = { + val wantToBuildLeft = canBuildLeft(joinType) && buildLeft + val wantToBuildRight = canBuildRight(joinType) && buildRight + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => + Seq(joins.BroadcastHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) + } + } + + def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = { + val wantToBuildLeft = canBuildLeft(joinType) && buildLeft + val wantToBuildRight = canBuildRight(joinType) && buildRight + getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) + } + } + + def createSortMergeJoin() = { + if (RowOrdering.isOrderable(leftKeys)) { + Some(Seq(joins.SortMergeJoinExec( + leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)))) + } else { + None + } + } + + def createCartesianProduct() = { + if (joinType.isInstanceOf[InnerLike]) { + Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition))) + } else { + None + } + } + + def createJoinWithoutHint() = { + createBroadcastHashJoin(canBroadcast(left), canBroadcast(right)) + .orElse { + if (!conf.preferSortMergeJoin) { + createShuffleHashJoin( + canBuildLocalHashMap(left) && muchSmaller(left, right), + canBuildLocalHashMap(right) && muchSmaller(right, left)) + } else { + None + } + } + .orElse(createSortMergeJoin()) + .orElse(createCartesianProduct()) + .getOrElse { + // This join could be very slow or OOM + val buildSide = getSmallerSide(left, right) + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } + } + createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None } + .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint))) + .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } + .getOrElse(createJoinWithoutHint()) + + // If it is not an equi-join, we first look at the join hints w.r.t. the following order: + // 1. broadcast hint: pick broadcast nested loop join. If both sides have the broadcast + // hints, choose the smaller side (based on stats) to broadcast. + // 2. shuffle replicate NL hint: pick cartesian product if join type is inner like. + // + // If there is no hint or the hints are not applicable, we follow these rules one by one: + // 1. Pick cartesian product if join type is inner like, and both sides are too big to + // to broadcast. + // 2. Pick broadcast nested loop join. Pick the smaller side (based on stats) to broadcast. case logical.Join(left, right, joinType, condition, hint) => - val buildSide = broadcastSide( - hint.leftHint.exists(_.broadcast), hint.rightHint.exists(_.broadcast), left, right) - // This join could be very slow or OOM - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = { + getBuildSide(buildLeft, buildRight, left, right).map { buildSide => + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } + } - // --- Cases where this strategy does not apply --------------------------------------------- + def createCartesianProduct() = { + if (joinType.isInstanceOf[InnerLike]) { + Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition))) + } else { + None + } + } + + def createJoinWithoutHint() = { + (if (!canBroadcast(left) && !canBroadcast(right)) createCartesianProduct() else None) + .getOrElse { + // This join could be very slow or OOM + val buildSide = getSmallerSide(left, right) + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } + } + + if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) { + createBroadcastNLJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } + .getOrElse(createJoinWithoutHint()) + } else { + val smallerSide = getSmallerSide(left, right) + val buildSide = if (canBuildLeft(joinType)) { + // For RIGHT JOIN, we may broadcast left side even if the hint asks us to broadcast + // the right side. This is for history reasons. + if (hintToBroadcastLeft(hint) || canBroadcast(left)) { + BuildLeft + } else if (hintToBroadcastRight(hint)) { + BuildRight + } else { + smallerSide + } + } else { + // For LEFT JOIN, we may broadcast right side even if the hint asks us to broadcast + // the left side. This is for history reasons. + if (hintToBroadcastRight(hint) || canBroadcast(right)) { + BuildRight + } else if (hintToBroadcastLeft(hint)) { + BuildLeft + } else { + smallerSide + } + } + Seq(joins.BroadcastNestedLoopJoinExec( + planLater(left), planLater(right), buildSide, joinType, condition)) + } + + // --- Cases where this strategy does not apply --------------------------------------------- case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d0be216e6a2f..79045e8a5aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf @@ -1045,7 +1045,7 @@ object functions { */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc) + ResolvedHint(df.logicalPlan, HintInfo(strategy = Some(BROADCAST))))(df.exprEnc) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 4d63390fb452..92157d8ad49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join} import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -951,7 +951,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext case Join(_, _, _, _, hint) => hint } assert(hint.size == 1) - assert(hint(0).leftHint.get.broadcast) + assert(hint(0).leftHint.get.strategy.contains(BROADCAST)) assert(hint(0).rightHint.isEmpty) // Clean-up diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 67f0f1a6fd23..9c2dc0c62b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -17,8 +17,14 @@ package org.apache.spark.sql +import scala.collection.mutable.ArrayBuffer + +import org.apache.log4j.{AppenderSkeleton, Level} +import org.apache.log4j.spi.LoggingEvent + import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -30,6 +36,41 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { lazy val df2 = df.selectExpr("id as b1", "id as b2") lazy val df3 = df.selectExpr("id as c1", "id as c2") + class MockAppender extends AppenderSkeleton { + val loggingEvents = new ArrayBuffer[LoggingEvent]() + + override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent) + override def close(): Unit = {} + override def requiresLayout(): Boolean = false + } + + def msgNoHintRelationFound(relation: String, hint: String): String = + s"Count not find relation '$relation' for join strategy hint '$hint'." + + def msgNoJoinForJoinHint(strategy: String): String = + s"A join hint (strategy=$strategy) is specified but it is not part of a join relation." + + def msgJoinHintOverridden(strategy: String): String = + s"Join hint (strategy=$strategy) is overridden by another hint and will not take effect." + + def verifyJoinHintWithWarnings( + df: => DataFrame, + expectedHints: Seq[JoinHint], + warnings: Seq[String]): Unit = { + val logAppender = new MockAppender() + withLogAppender(logAppender) { + verifyJoinHint(df, expectedHints) + } + val warningMessages = logAppender.loggingEvents + .filter(_.getLevel == Level.WARN) + .map(_.getRenderedMessage) + .filter(_.contains("hint")) + assert(warningMessages.size == warnings.size) + warnings.foreach { w => + assert(warningMessages.contains(w)) + } + } + def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = { val optimized = df.queryExecution.optimizedPlan val joinHints = optimized collect { @@ -43,14 +84,14 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { verifyJoinHint( df.hint("broadcast").join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( df.join(df.hint("broadcast"), "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) } @@ -59,18 +100,18 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"), 'a1 === 'c1), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3, 'a1 === 'c1), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) } @@ -89,13 +130,13 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { |) b on a.a1 = b.b1 """.stripMargin), JoinHint( - Some(HintInfo(broadcast = true)), - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST))), + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) } @@ -112,9 +153,9 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: Nil ) verifyJoinHint( @@ -122,25 +163,25 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint.NONE :: JoinHint( - Some(HintInfo(broadcast = true)), - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST))), + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) verifyJoinHint( sql("select /*+ broadcast(b, c)*/ * from a, c, b " + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( None, - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil ) verifyJoinHint( df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") .join(df3, 'b1 === 'c1 && 'a1 < 10), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint.NONE :: Nil ) @@ -151,7 +192,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { .join(df, 'b1 === 'id), JoinHint.NONE :: JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint.NONE :: Nil ) @@ -164,7 +205,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { verifyJoinHint( df.hint("broadcast").except(dfSub).join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), None) :: JoinHint.NONE :: Nil ) @@ -172,31 +213,112 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { df.join(df.hint("broadcast").intersect(dfSub), "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint.NONE :: Nil ) } test("hint merge") { - verifyJoinHint( + verifyJoinHintWithWarnings( df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), - None) :: Nil + Some(HintInfo(strategy = Some(BROADCAST))), + None) :: Nil, + Nil ) - verifyJoinHint( + verifyJoinHintWithWarnings( df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: Nil + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil, + Nil + ) + verifyJoinHintWithWarnings( + df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id").hint("broadcast"), + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_HASH))), + None) :: Nil, + msgJoinHintOverridden("merge") :: + msgNoJoinForJoinHint("broadcast") :: Nil + ) + verifyJoinHintWithWarnings( + df.join(df.hint("broadcast").limit(2).hint("merge"), "id") + .hint("shuffle_hash") + .hint("shuffle_replicate_nl") + .join(df, "id"), + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))), + None) :: + JoinHint( + None, + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil, + msgJoinHintOverridden("broadcast") :: + msgJoinHintOverridden("shuffle_hash") :: Nil ) } + test("hint merge - SQL") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + df3.createOrReplaceTempView("c") + verifyJoinHintWithWarnings( + sql("select /*+ shuffle_hash merge(a, c) broadcast(a, b)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint( + None, + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), + Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil, + msgNoJoinForJoinHint("shuffle_hash") :: + msgJoinHintOverridden("broadcast") :: Nil + ) + verifyJoinHintWithWarnings( + sql("select /*+ shuffle_hash(a, b) merge(b, d) broadcast(b)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_HASH))), + Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil, + msgNoHintRelationFound("d", "merge(b, d)") :: + msgJoinHintOverridden("broadcast") :: + msgJoinHintOverridden("merge") :: Nil + ) + verifyJoinHintWithWarnings( + sql( + """ + |select /*+ broadcast(a, c) merge(a, d)*/ * from a + |join ( + | select /*+ shuffle_hash(c) shuffle_replicate_nl(b, c)*/ * from b + | join c on b.b1 = c.c1 + |) as d + |on a.a2 = d.b2 + """.stripMargin), + JoinHint( + Some(HintInfo(strategy = Some(BROADCAST))), + Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))), + Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil, + msgNoHintRelationFound("c", "broadcast(a, c)") :: + msgJoinHintOverridden("merge") :: + msgJoinHintOverridden("shuffle_replicate_nl") :: Nil + ) + } + } + test("nested hint") { verifyJoinHint( df.hint("broadcast").hint("broadcast").filter('id > 2).join(df, "id"), JoinHint( - Some(HintInfo(broadcast = true)), + Some(HintInfo(strategy = Some(BROADCAST))), + None) :: Nil + ) + verifyJoinHint( + df.hint("shuffle_hash").hint("broadcast").hint("merge").filter('id > 2).join(df, "id"), + JoinHint( + Some(HintInfo(strategy = Some(SHUFFLE_MERGE))), None) :: Nil ) } @@ -209,12 +331,230 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { join.join(broadcasted, "id").join(broadcasted, "id"), JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint( None, - Some(HintInfo(broadcast = true))) :: + Some(HintInfo(strategy = Some(BROADCAST)))) :: JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil ) } } + + def equiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"): String = + hints.map("/*+ " + _ + " */").mkString( + "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key = t2.key") + + def nonEquiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"): String = + hints.map("/*+ " + _ + " */").mkString( + "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key > t2.key") + + private def assertBroadcastHashJoin(df: DataFrame, buildSide: BuildSide): Unit = { + val executedPlan = df.queryExecution.executedPlan + val broadcastHashJoins = executedPlan.collect { + case b: BroadcastHashJoinExec => b + } + assert(broadcastHashJoins.size == 1) + assert(broadcastHashJoins.head.buildSide == buildSide) + } + + private def assertBroadcastNLJoin(df: DataFrame, buildSide: BuildSide): Unit = { + val executedPlan = df.queryExecution.executedPlan + val broadcastNLJoins = executedPlan.collect { + case b: BroadcastNestedLoopJoinExec => b + } + assert(broadcastNLJoins.size == 1) + assert(broadcastNLJoins.head.buildSide == buildSide) + } + + private def assertShuffleHashJoin(df: DataFrame, buildSide: BuildSide): Unit = { + val executedPlan = df.queryExecution.executedPlan + val shuffleHashJoins = executedPlan.collect { + case s: ShuffledHashJoinExec => s + } + assert(shuffleHashJoins.size == 1) + assert(shuffleHashJoins.head.buildSide == buildSide) + } + + private def assertShuffleMergeJoin(df: DataFrame): Unit = { + val executedPlan = df.queryExecution.executedPlan + val shuffleMergeJoins = executedPlan.collect { + case s: SortMergeJoinExec => s + } + assert(shuffleMergeJoins.size == 1) + } + + private def assertShuffleReplicateNLJoin(df: DataFrame): Unit = { + val executedPlan = df.queryExecution.executedPlan + val shuffleReplicateNLJoins = executedPlan.collect { + case c: CartesianProductExec => c + } + assert(shuffleReplicateNLJoins.size == 1) + } + + test("join strategy hint - broadcast") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + // Broadcast hint specified on one side + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil)), BuildLeft) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("BROADCAST(t2)" :: Nil)), BuildRight) + + // Determine build side based on the join type and child relation sizes + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil)), BuildLeft) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "right")), BuildLeft) + + // Use broadcast-hash join if hinted "broadcast" and equi-join + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)), BuildRight) + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t1, t2)" :: Nil)), BuildLeft) + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil)), + BuildLeft) + + // Use broadcast-nl join if hinted "broadcast" and non-equi-join + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "BROADCAST(t1)" :: Nil)), BuildLeft) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "BROADCAST(t2)" :: Nil)), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: "BROADCAST(t2)" :: Nil)), + BuildRight) + + // Broadcast hint specified but not doable + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil, "left"))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: Nil, "right"))) + } + } + } + + test("join strategy hint - shuffle-merge") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) { + // Shuffle-merge hint specified on one side + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("SHUFFLE_MERGE(t1)" :: Nil))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("MERGEJOIN(t2)" :: Nil))) + + // Shuffle-merge hint specified on both sides + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("MERGE(t1, t2)" :: Nil))) + + // Shuffle-merge hint prioritized over shuffle-hash hint and shuffle-replicate-nl hint + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "MERGE(t1)" :: Nil, "left"))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("MERGE(t2)" :: "SHUFFLE_HASH(t1)" :: Nil, "right"))) + + // Broadcast hint prioritized over shuffle-merge hint, but broadcast hint is not applicable + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t2)" :: Nil, "left"))) + assertShuffleMergeJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "MERGE(t1)" :: Nil, "right"))) + + // Shuffle-merge hint specified but not doable + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildRight) + } + } + } + + test("join strategy hint - shuffle-hash") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) { + // Shuffle-hash hint specified on one side + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: Nil)), BuildRight) + + // Determine build side based on the join type and child relation sizes + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil)), BuildLeft) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "left")), BuildRight) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "right")), BuildLeft) + + // Shuffle-hash hint prioritized over shuffle-replicate-nl hint + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)), + BuildLeft) + + // Broadcast hint prioritized over shuffle-hash hint, but broadcast hint is not applicable + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_HASH(t2)" :: Nil, "left")), + BuildRight) + assertShuffleHashJoin( + sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil, "right")), + BuildLeft) + + // Shuffle-hash hint specified but not doable + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil, "left")), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft) + } + } + } + + test("join strategy hint - shuffle-replicate-nl") { + withTempView("t1", "t2") { + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) { + // Shuffle-replicate-nl hint specified on one side + assertShuffleReplicateNLJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: Nil))) + assertShuffleReplicateNLJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: Nil))) + + // Shuffle-replicate-nl hint specified on both sides + assertShuffleReplicateNLJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil))) + + // Shuffle-merge hint prioritized over shuffle-replicate-nl hint, but shuffle-merge hint + // is not applicable + assertShuffleReplicateNLJoin( + sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil))) + + // Shuffle-hash hint prioritized over shuffle-replicate-nl hint, but shuffle-hash hint is + // not applicable + assertShuffleReplicateNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "SHUFFLE_REPLICATE_NL(t1)" :: Nil))) + + // Shuffle-replicate-nl hint specified but not doable + assertBroadcastHashJoin( + sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "left")), BuildRight) + assertBroadcastNLJoin( + sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "right")), BuildLeft) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index f238148e61c3..05c583c80e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements @@ -216,10 +217,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast) + assert(plan1.asInstanceOf[Join].hint.leftHint.get.strategy.contains(BROADCAST)) assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty) assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty) - assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast) + assert(plan2.asInstanceOf[Join].hint.rightHint.get.strategy.contains(BROADCAST)) assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty) assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty) }