diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md
index 68569744afcc..2a1edda84252 100644
--- a/docs/sql-performance-tuning.md
+++ b/docs/sql-performance-tuning.md
@@ -107,14 +107,22 @@ that these options will be deprecated in future release as more optimizations ar
-## Broadcast Hint for SQL Queries
-
-The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view.
-When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred,
-even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`.
-When both sides of a join are specified, Spark broadcasts the one having the lower statistics.
-Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join)
-support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
+## Join Strategy Hints for SQL Queries
+
+The join strategy hints, namely `BROADCAST`, `MERGE`, `SHUFFLE_HASH` and `SHUFFLE_REPLICATE_NL`,
+instruct Spark to use the hinted strategy on each specified relation when joining them with another
+relation. For example, when the `BROADCAST` hint is used on table 't1', broadcast join (either
+broadcast hash join or broadcast nested loop join depending on whether there is any equi-join key)
+with 't1' as the build side will be prioritized by Spark even if the size of table 't1' suggested
+by the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`.
+
+When different join strategy hints are specified on both sides of a join, Spark prioritizes the
+`BROADCAST` hint over the `MERGE` hint over the `SHUFFLE_HASH` hint over the `SHUFFLE_REPLICATE_NL`
+hint. When both sides are specified with the `BROADCAST` hint or the `SHUFFLE_HASH` hint, Spark will
+pick the build side based on the join type and the sizes of the relations.
+
+Note that there is no guarantee that Spark will choose the join strategy specified in the hint since
+a specific strategy may not support all join types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 01e40e64a3e8..02d83e7e8cb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -153,7 +153,7 @@ class Analyzer(
lazy val batches: Seq[Batch] = Seq(
Batch("Hints", fixedPoint,
- new ResolveHints.ResolveBroadcastHints(conf),
+ new ResolveHints.ResolveJoinStrategyHints(conf),
ResolveHints.ResolveCoalesceHints,
ResolveHints.RemoveAllHints),
Batch("Simple Sanity Check", Once,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index dbd4ed845e32..9440a3f806b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
+import scala.collection.mutable
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical._
@@ -28,45 +30,66 @@ import org.apache.spark.sql.internal.SQLConf
/**
- * Collection of rules related to hints. The only hint currently available is broadcast join hint.
+ * Collection of rules related to hints. The only hint currently available is join strategy hint.
*
* Note that this is separately into two rules because in the future we might introduce new hint
- * rules that have different ordering requirements from broadcast.
+ * rules that have different ordering requirements from join strategies.
*/
object ResolveHints {
/**
- * For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of
- * relation aliases can be specified in the hint. A broadcast hint plan node will be inserted
- * on top of any relation (that is not aliased differently), subquery, or common table expression
- * that match the specified name.
+ * The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a
+ * sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)",
+ * "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that
+ * is not aliased differently), subquery, or common table expression that match the specified
+ * name.
*
* The hint resolution works by recursively traversing down the query plan to find a relation or
- * subquery that matches one of the specified broadcast aliases. The traversal does not go past
- * beyond any existing broadcast hints, subquery aliases.
+ * subquery that matches one of the specified relation aliases. The traversal does not go past
+ * beyond any view reference, with clause or subquery alias.
*
* This rule must happen before common table expressions.
*/
- class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] {
- private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN")
+ class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] {
+ private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases)
def resolver: Resolver = conf.resolver
- private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = {
+ private def createHintInfo(hintName: String): HintInfo = {
+ HintInfo(strategy =
+ JoinStrategyHint.strategies.find(
+ _.hintAliases.map(
+ _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT))))
+ }
+
+ private def applyJoinStrategyHint(
+ plan: LogicalPlan,
+ relations: mutable.HashSet[String],
+ hintName: String): LogicalPlan = {
// Whether to continue recursing down the tree
var recurse = true
val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
- case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
- ResolvedHint(plan, HintInfo(broadcast = true))
- case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
- ResolvedHint(plan, HintInfo(broadcast = true))
+ case ResolvedHint(u: UnresolvedRelation, hint)
+ if relations.exists(resolver(_, u.tableIdentifier.table)) =>
+ relations.remove(u.tableIdentifier.table)
+ ResolvedHint(u, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))
+ case ResolvedHint(r: SubqueryAlias, hint)
+ if relations.exists(resolver(_, r.alias)) =>
+ relations.remove(r.alias)
+ ResolvedHint(r, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))
+
+ case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) =>
+ relations.remove(u.tableIdentifier.table)
+ ResolvedHint(plan, createHintInfo(hintName))
+ case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
+ relations.remove(r.alias)
+ ResolvedHint(plan, createHintInfo(hintName))
case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
// Don't traverse down these nodes.
- // For an existing broadcast hint, there is no point going down (if we do, we either
- // won't change the structure, or will introduce another broadcast hint that is useless.
+ // For an existing strategy hint, there is no chance for a match from this point down.
// The rest (view, with, subquery) indicates different scopes that we shouldn't traverse
// down. Note that technically when this rule is executed, we haven't completed view
// resolution yet and as a result the view part should be deadcode. I'm leaving it here
@@ -80,25 +103,38 @@ object ResolveHints {
}
if ((plan fastEquals newNode) && recurse) {
- newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast))
+ newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName))
} else {
newNode
}
}
+ private def handleOverriddenHintInfo(hint: HintInfo): Unit = {
+ logWarning(s"Join hint $hint is overridden by another hint and will not take effect.")
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
- case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
+ case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
if (h.parameters.isEmpty) {
- // If there is no table alias specified, turn the entire subtree into a BroadcastHint.
- ResolvedHint(h.child, HintInfo(broadcast = true))
+ // If there is no table alias specified, apply the hint on the entire subtree.
+ ResolvedHint(h.child, createHintInfo(h.name))
} else {
- // Otherwise, find within the subtree query plans that should be broadcasted.
- applyBroadcastHint(h.child, h.parameters.map {
+ // Otherwise, find within the subtree query plans to apply the hint.
+ val relationNames = h.parameters.map {
case tableName: String => tableName
case tableId: UnresolvedAttribute => tableId.name
- case unsupported => throw new AnalysisException("Broadcast hint parameter should be " +
- s"an identifier or string but was $unsupported (${unsupported.getClass}")
- }.toSet)
+ case unsupported => throw new AnalysisException("Join strategy hint parameter " +
+ s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
+ }
+ val relationNameSet = new mutable.HashSet[String]
+ relationNames.foreach(relationNameSet.add)
+
+ val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
+ relationNameSet.foreach { n =>
+ logWarning(s"Count not find relation '$n' for join strategy hint " +
+ s"'${h.name}${relationNames.mkString("(", ", ", ")")}'.")
+ }
+ applied
}
}
}
@@ -135,7 +171,9 @@ object ResolveHints {
*/
object RemoveAllHints extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
- case h: UnresolvedHint => h.child
+ case h: UnresolvedHint =>
+ logWarning(s"Unrecognized hint: ${h.name}${h.parameters.mkString("(", ", ", ")")}")
+ h.child
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 60066371e3dc..2d646721f87a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -374,7 +374,7 @@ object CatalogTable {
/**
* This class of statistics is used in [[CatalogTable]] to interact with metastore.
* We define this new class instead of directly using [[Statistics]] here because there are no
- * concepts of attributes or broadcast hint in catalog.
+ * concepts of attributes in catalog.
*/
case class CatalogStatistics(
sizeInBytes: BigInt,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
index a136f0493699..5586690520c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateResolvedHint.scala
@@ -30,30 +30,58 @@ object EliminateResolvedHint extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val pulledUp = plan transformUp {
case j: Join =>
- val leftHint = mergeHints(collectHints(j.left))
- val rightHint = mergeHints(collectHints(j.right))
- j.copy(hint = JoinHint(leftHint, rightHint))
+ val (newLeft, leftHints) = extractHintsFromPlan(j.left)
+ val (newRight, rightHints) = extractHintsFromPlan(j.right)
+ val newJoinHint = JoinHint(mergeHints(leftHints), mergeHints(rightHints))
+ j.copy(left = newLeft, right = newRight, hint = newJoinHint)
}
pulledUp.transformUp {
- case h: ResolvedHint => h.child
+ case h: ResolvedHint =>
+ handleInvalidHintInfo(h.hints)
+ h.child
}
}
+ /**
+ * Combine a list of [[HintInfo]]s into one [[HintInfo]].
+ */
private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = {
- hints.reduceOption((h1, h2) => HintInfo(
- broadcast = h1.broadcast || h2.broadcast))
+ hints.reduceOption((h1, h2) => h1.merge(h2, handleOverriddenHintInfo))
}
- private def collectHints(plan: LogicalPlan): Seq[HintInfo] = {
+ /**
+ * Extract all hints from the plan, returning a list of extracted hints and the transformed plan
+ * with [[ResolvedHint]] nodes removed. The returned hint list comes in top-down order.
+ * Note that hints can only be extracted from under certain nodes. Those that cannot be extracted
+ * in this method will be cleaned up later by this rule, and may emit warnings depending on the
+ * configurations.
+ */
+ private def extractHintsFromPlan(plan: LogicalPlan): (LogicalPlan, Seq[HintInfo]) = {
plan match {
- case h: ResolvedHint => collectHints(h.child) :+ h.hints
- case u: UnaryNode => collectHints(u.child)
+ case h: ResolvedHint =>
+ val (plan, hints) = extractHintsFromPlan(h.child)
+ (plan, h.hints +: hints)
+ case u: UnaryNode =>
+ val (plan, hints) = extractHintsFromPlan(u.child)
+ (u.withNewChildren(Seq(plan)), hints)
// TODO revisit this logic:
// except and intersect are semi/anti-joins which won't return more data then
// their left argument, so the broadcast hint should be propagated here
- case i: Intersect => collectHints(i.left)
- case e: Except => collectHints(e.left)
- case _ => Seq.empty
+ case i: Intersect =>
+ val (plan, hints) = extractHintsFromPlan(i.left)
+ (i.copy(left = plan), hints)
+ case e: Except =>
+ val (plan, hints) = extractHintsFromPlan(e.left)
+ (e.copy(left = plan), hints)
+ case p: LogicalPlan => (p, Seq.empty)
}
}
+
+ private def handleInvalidHintInfo(hint: HintInfo): Unit = {
+ logWarning(s"A join hint $hint is specified but it is not part of a join relation.")
+ }
+
+ private def handleOverriddenHintInfo(hint: HintInfo): Unit = {
+ logWarning(s"Join hint $hint is overridden by another hint and will not take effect.")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index b2ba725e9d44..870dd87d8a36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -66,17 +66,94 @@ object JoinHint {
/**
* The hint attributes to be applied on a specific node.
*
- * @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join
- * strategy and the node with this hint is preferred to be the build side.
+ * @param strategy The preferred join strategy.
*/
-case class HintInfo(broadcast: Boolean = false) {
+case class HintInfo(strategy: Option[JoinStrategyHint] = None) {
- override def toString: String = {
- val hints = scala.collection.mutable.ArrayBuffer.empty[String]
- if (broadcast) {
- hints += "broadcast"
+ /**
+ * Combine this [[HintInfo]] with another [[HintInfo]] and return the new [[HintInfo]].
+ * @param other the other [[HintInfo]]
+ * @param hintOverriddenCallback a callback to notify if any [[HintInfo]] has been overridden
+ * in this merge.
+ *
+ * Currently, for join strategy hints, the new [[HintInfo]] will contain the strategy in this
+ * [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]]. The
+ * `hintOverriddenCallback` will be called if this [[HintInfo]] and the other [[HintInfo]]
+ * both have a strategy defined but the join strategies are different.
+ */
+ def merge(other: HintInfo, hintOverriddenCallback: HintInfo => Unit): HintInfo = {
+ if (this.strategy.isDefined &&
+ other.strategy.isDefined &&
+ this.strategy.get != other.strategy.get) {
+ hintOverriddenCallback(other)
}
-
- if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")")
+ HintInfo(strategy = this.strategy.orElse(other.strategy))
}
+
+ override def toString: String = strategy.map(s => s"(strategy=$s)").getOrElse("none")
+}
+
+sealed abstract class JoinStrategyHint {
+
+ def displayName: String
+ def hintAliases: Set[String]
+
+ override def toString: String = displayName
+}
+
+/**
+ * The enumeration of join strategy hints.
+ *
+ * The hinted strategy will be used for the join with which it is associated if doable. In case
+ * of contradicting strategy hints specified for each side of the join, hints are prioritized as
+ * BROADCAST over SHUFFLE_MERGE over SHUFFLE_HASH over SHUFFLE_REPLICATE_NL.
+ */
+object JoinStrategyHint {
+
+ val strategies: Set[JoinStrategyHint] = Set(
+ BROADCAST,
+ SHUFFLE_MERGE,
+ SHUFFLE_HASH,
+ SHUFFLE_REPLICATE_NL)
+}
+
+/**
+ * The hint for broadcast hash join or broadcast nested loop join, depending on the availability of
+ * equi-join keys.
+ */
+case object BROADCAST extends JoinStrategyHint {
+ override def displayName: String = "broadcast"
+ override def hintAliases: Set[String] = Set(
+ "BROADCAST",
+ "BROADCASTJOIN",
+ "MAPJOIN")
+}
+
+/**
+ * The hint for shuffle sort merge join.
+ */
+case object SHUFFLE_MERGE extends JoinStrategyHint {
+ override def displayName: String = "merge"
+ override def hintAliases: Set[String] = Set(
+ "SHUFFLE_MERGE",
+ "MERGE",
+ "MERGEJOIN")
+}
+
+/**
+ * The hint for shuffle hash join.
+ */
+case object SHUFFLE_HASH extends JoinStrategyHint {
+ override def displayName: String = "shuffle_hash"
+ override def hintAliases: Set[String] = Set(
+ "SHUFFLE_HASH")
+}
+
+/**
+ * The hint for shuffle-and-replicate nested loop join, a.k.a. cartesian product join.
+ */
+case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint {
+ override def displayName: String = "shuffle_replicate_nl"
+ override def hintAliases: Set[String] = Set(
+ "SHUFFLE_REPLICATE_NL")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 563e8adf87ed..474e58a335e7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.log4j.{AppenderSkeleton, Level}
+import org.apache.log4j.spi.LoggingEvent
+
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -27,6 +32,14 @@ import org.apache.spark.sql.catalyst.plans.logical._
class ResolveHintsSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
+ class MockAppender extends AppenderSkeleton {
+ val loggingEvents = new ArrayBuffer[LoggingEvent]()
+
+ override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent)
+ override def close(): Unit = {}
+ override def requiresLayout(): Boolean = false
+ }
+
test("invalid hints should be ignored") {
checkAnalysis(
UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
@@ -37,17 +50,17 @@ class ResolveHintsSuite extends AnalysisTest {
test("case-sensitive or insensitive parameters") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = true)
checkAnalysis(
@@ -59,28 +72,29 @@ class ResolveHintsSuite extends AnalysisTest {
test("multiple broadcast hint aliases") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
- Join(ResolvedHint(testRelation, HintInfo(broadcast = true)),
- ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE),
+ Join(ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+ ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
+ Inner, None, JoinHint.NONE),
caseSensitive = false)
}
test("do not traverse past existing broadcast hints") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"),
- ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))),
- ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze,
+ ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))),
+ ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze,
caseSensitive = false)
}
test("should work for subqueries") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
- ResolvedHint(testRelation, HintInfo(broadcast = true)),
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)
// Negative case: if the alias doesn't match, don't match the original table name.
@@ -105,7 +119,7 @@ class ResolveHintsSuite extends AnalysisTest {
|SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
""".stripMargin
),
- ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true))
+ ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(strategy = Some(BROADCAST)))
.select('a).analyze,
caseSensitive = false)
}
@@ -155,4 +169,17 @@ class ResolveHintsSuite extends AnalysisTest {
UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")),
Seq(errMsgRepa))
}
+
+ test("log warnings for invalid hints") {
+ val logAppender = new MockAppender()
+ withLogAppender(logAppender) {
+ checkAnalysis(
+ UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE")),
+ testRelation,
+ caseSensitive = false)
+ }
+ assert(logAppender.loggingEvents.exists(
+ e => e.getLevel == Level.WARN &&
+ e.getRenderedMessage.contains("Unrecognized hint: unknown_hint")))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index bf38189c2fb5..efd05a3e2b3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -90,61 +90,35 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Select the proper physical plan for join based on joining keys and size of logical plan.
- *
- * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the
- * predicates can be evaluated by matching join keys. If found, join implementations are chosen
- * with the following precedence:
+ * Select the proper physical plan for join based on join strategy hints, the availability of
+ * equi-join keys and the sizes of joining relations. Below are the existing join strategies,
+ * their characteristics and their limitations.
*
* - Broadcast hash join (BHJ):
- * BHJ is not supported for full outer join. For right outer join, we only can broadcast the
- * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin,
- * we only can broadcast the right side. For inner like join, we can broadcast both sides.
- * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is
- * small. However, broadcasting tables is a network-intensive operation. It could cause OOM
- * or perform worse than the other join algorithms, especially when the build/broadcast side
- * is big.
- *
- * For the supported cases, users can specify the broadcast hint (e.g. the user applied the
- * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based
- * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and
- * which join side is broadcast.
- *
- * 1) Broadcast the join side with the broadcast hint, even if the size is larger than
- * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type
- * is inner like join), the side with a smaller estimated physical size will be broadcast.
- * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side
- * whose estimated physical size is smaller than the threshold. If both sides are below the
- * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used.
- *
- * - Shuffle hash join: if the average size of a single partition is small enough to build a hash
- * table.
- *
- * - Sort merge: if the matching join keys are sortable.
- *
- * If there is no joining keys, Join implementations are chosen with the following precedence:
- * - BroadcastNestedLoopJoin (BNLJ):
- * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios:
- * For right outer join, the left side is broadcast. For left outer, left semi, left anti
- * and the internal join type ExistenceJoin, the right side is broadcast. For inner like
- * joins, either side is broadcast.
+ * Only supported for equi-joins, while the join keys do not need to be sortable.
+ * Supported for all join types except full outer joins.
+ * BHJ usually performs faster than the other join algorithms when the broadcast side is
+ * small. However, broadcasting tables is a network-intensive operation and it could cause
+ * OOM or perform badly in some cases, especially when the build/broadcast side is big.
*
- * Like BHJ, users still can specify the broadcast hint and session-based
- * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast.
+ * - Shuffle hash join:
+ * Only supported for equi-joins, while the join keys do not need to be sortable.
+ * Supported for all join types except full outer joins.
*
- * 1) Broadcast the join side with the broadcast hint, even if the size is larger than
- * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for
- * inner-like join), the side with a smaller estimated physical size will be broadcast.
- * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side
- * whose estimated physical size is smaller than the threshold. If both sides are below the
- * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used.
+ * - Shuffle sort merge join (SMJ):
+ * Only supported for equi-joins and the join keys have to be sortable.
+ * Supported for all join types.
*
- * - CartesianProduct: for inner like join, CartesianProduct is the fallback option.
+ * - Broadcast nested loop join (BNLJ):
+ * Supports both equi-joins and non-equi-joins.
+ * Supports all the join types, but the implementation is optimized for:
+ * 1) broadcasting the left side in a right outer join;
+ * 2) broadcasting the right side in a left outer, left semi, left anti or existence join;
+ * 3) broadcasting either side in an inner-like join.
*
- * - BroadcastNestedLoopJoin (BNLJ):
- * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast
- * side with the broadcast hint. If neither side has a hint, we broadcast the side with
- * the smaller estimated physical size.
+ * - Shuffle-and-replicate nested loop join (a.k.a. cartesian product join):
+ * Supports both equi-joins and non-equi-joins.
+ * Supports only inner like joins.
*/
object JoinSelection extends Strategy with PredicateHelper {
@@ -186,126 +160,218 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => false
}
- private def broadcastSide(
- canBuildLeft: Boolean,
- canBuildRight: Boolean,
+ private def getBuildSide(
+ wantToBuildLeft: Boolean,
+ wantToBuildRight: Boolean,
left: LogicalPlan,
- right: LogicalPlan): BuildSide = {
-
- def smallerSide =
- if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
-
- if (canBuildRight && canBuildLeft) {
- // Broadcast smaller side base on its estimated physical size
- // if both sides have broadcast hint
- smallerSide
- } else if (canBuildRight) {
- BuildRight
- } else if (canBuildLeft) {
- BuildLeft
+ right: LogicalPlan): Option[BuildSide] = {
+ if (wantToBuildLeft && wantToBuildRight) {
+ // returns the smaller side base on its estimated physical size, if we want to build the
+ // both sides.
+ Some(getSmallerSide(left, right))
+ } else if (wantToBuildLeft) {
+ Some(BuildLeft)
+ } else if (wantToBuildRight) {
+ Some(BuildRight)
} else {
- // for the last default broadcast nested loop join
- smallerSide
+ None
}
}
- private def canBroadcastByHints(
- joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): Boolean = {
- val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast)
- val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast)
- buildLeft || buildRight
+ private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = {
+ if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
}
- private def broadcastSideByHints(
- joinType: JoinType, left: LogicalPlan, right: LogicalPlan, hint: JoinHint): BuildSide = {
- val buildLeft = canBuildLeft(joinType) && hint.leftHint.exists(_.broadcast)
- val buildRight = canBuildRight(joinType) && hint.rightHint.exists(_.broadcast)
- broadcastSide(buildLeft, buildRight, left, right)
+ private def hintToBroadcastLeft(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(BROADCAST))
}
- private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
- : Boolean = {
- val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
- val buildRight = canBuildRight(joinType) && canBroadcast(right)
- buildLeft || buildRight
+ private def hintToBroadcastRight(hint: JoinHint): Boolean = {
+ hint.rightHint.exists(_.strategy.contains(BROADCAST))
}
- private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
- : BuildSide = {
- val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
- val buildRight = canBuildRight(joinType) && canBroadcast(right)
- broadcastSide(buildLeft, buildRight, left, right)
+ private def hintToShuffleHashLeft(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH))
+ }
+
+ private def hintToShuffleHashRight(hint: JoinHint): Boolean = {
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH))
+ }
+
+ private def hintToSortMergeJoin(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) ||
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
+ }
+
+ private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) ||
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- // --- BroadcastHashJoin --------------------------------------------------------------------
-
- // broadcast hints were specified
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint)
- if canBroadcastByHints(joinType, left, right, hint) =>
- val buildSide = broadcastSideByHints(joinType, left, right, hint)
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
-
- // broadcast hints were not specified, so need to infer it from size and configuration.
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
- if canBroadcastBySizes(joinType, left, right) =>
- val buildSide = broadcastSideBySizes(joinType, left, right)
- Seq(joins.BroadcastHashJoinExec(
- leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
-
- // --- ShuffledHashJoin ---------------------------------------------------------------------
-
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
- if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
- && muchSmaller(right, left) ||
- !RowOrdering.isOrderable(leftKeys) =>
- Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
-
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
- if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
- && muchSmaller(left, right) ||
- !RowOrdering.isOrderable(leftKeys) =>
- Seq(joins.ShuffledHashJoinExec(
- leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
-
- // --- SortMergeJoin ------------------------------------------------------------
-
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
- if RowOrdering.isOrderable(leftKeys) =>
- joins.SortMergeJoinExec(
- leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
-
- // --- Without joining keys ------------------------------------------------------------
-
- // Pick BroadcastNestedLoopJoin if one side could be broadcast
- case j @ logical.Join(left, right, joinType, condition, hint)
- if canBroadcastByHints(joinType, left, right, hint) =>
- val buildSide = broadcastSideByHints(joinType, left, right, hint)
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
-
- case j @ logical.Join(left, right, joinType, condition, _)
- if canBroadcastBySizes(joinType, left, right) =>
- val buildSide = broadcastSideBySizes(joinType, left, right)
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
-
- // Pick CartesianProduct for InnerJoin
- case logical.Join(left, right, _: InnerLike, condition, _) =>
- joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
+ // If it is an equi-join, we first look at the join hints w.r.t. the following order:
+ // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides
+ // have the broadcast hints, choose the smaller side (based on stats) to broadcast.
+ // 2. sort merge hint: pick sort merge join if join keys are sortable.
+ // 3. shuffle hash hint: We pick shuffle hash join if the join type is supported. If both
+ // sides have the shuffle hash hints, choose the smaller side (based on stats) as the
+ // build side.
+ // 4. shuffle replicate NL hint: pick cartesian product if join type is inner like.
+ //
+ // If there is no hint or the hints are not applicable, we follow these rules one by one:
+ // 1. Pick broadcast hash join if one side is small enough to broadcast, and the join type
+ // is supported. If both sides are small, choose the smaller side (based on stats)
+ // to broadcast.
+ // 2. Pick shuffle hash join if one side is small enough to build local hash map, and is
+ // much smaller than the other side, and `spark.sql.join.preferSortMergeJoin` is false.
+ // 3. Pick sort merge join if the join keys are sortable.
+ // 4. Pick cartesian product if join type is inner like.
+ // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have
+ // other choice.
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
+ def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
+ val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
+ val wantToBuildRight = canBuildRight(joinType) && buildRight
+ getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide =>
+ Seq(joins.BroadcastHashJoinExec(
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide,
+ condition,
+ planLater(left),
+ planLater(right)))
+ }
+ }
+
+ def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
+ val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
+ val wantToBuildRight = canBuildRight(joinType) && buildRight
+ getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide =>
+ Seq(joins.ShuffledHashJoinExec(
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide,
+ condition,
+ planLater(left),
+ planLater(right)))
+ }
+ }
+
+ def createSortMergeJoin() = {
+ if (RowOrdering.isOrderable(leftKeys)) {
+ Some(Seq(joins.SortMergeJoinExec(
+ leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right))))
+ } else {
+ None
+ }
+ }
+
+ def createCartesianProduct() = {
+ if (joinType.isInstanceOf[InnerLike]) {
+ Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)))
+ } else {
+ None
+ }
+ }
+
+ def createJoinWithoutHint() = {
+ createBroadcastHashJoin(canBroadcast(left), canBroadcast(right))
+ .orElse {
+ if (!conf.preferSortMergeJoin) {
+ createShuffleHashJoin(
+ canBuildLocalHashMap(left) && muchSmaller(left, right),
+ canBuildLocalHashMap(right) && muchSmaller(right, left))
+ } else {
+ None
+ }
+ }
+ .orElse(createSortMergeJoin())
+ .orElse(createCartesianProduct())
+ .getOrElse {
+ // This join could be very slow or OOM
+ val buildSide = getSmallerSide(left, right)
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType, condition))
+ }
+ }
+ createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint))
+ .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None }
+ .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint)))
+ .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
+ .getOrElse(createJoinWithoutHint())
+
+ // If it is not an equi-join, we first look at the join hints w.r.t. the following order:
+ // 1. broadcast hint: pick broadcast nested loop join. If both sides have the broadcast
+ // hints, choose the smaller side (based on stats) to broadcast.
+ // 2. shuffle replicate NL hint: pick cartesian product if join type is inner like.
+ //
+ // If there is no hint or the hints are not applicable, we follow these rules one by one:
+ // 1. Pick cartesian product if join type is inner like, and both sides are too big to
+ // to broadcast.
+ // 2. Pick broadcast nested loop join. Pick the smaller side (based on stats) to broadcast.
case logical.Join(left, right, joinType, condition, hint) =>
- val buildSide = broadcastSide(
- hint.leftHint.exists(_.broadcast), hint.rightHint.exists(_.broadcast), left, right)
- // This join could be very slow or OOM
- joins.BroadcastNestedLoopJoinExec(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+ def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {
+ getBuildSide(buildLeft, buildRight, left, right).map { buildSide =>
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType, condition))
+ }
+ }
- // --- Cases where this strategy does not apply ---------------------------------------------
+ def createCartesianProduct() = {
+ if (joinType.isInstanceOf[InnerLike]) {
+ Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)))
+ } else {
+ None
+ }
+ }
+
+ def createJoinWithoutHint() = {
+ (if (!canBroadcast(left) && !canBroadcast(right)) createCartesianProduct() else None)
+ .getOrElse {
+ // This join could be very slow or OOM
+ val buildSide = getSmallerSide(left, right)
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType, condition))
+ }
+ }
+
+ if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) {
+ createBroadcastNLJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint))
+ .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
+ .getOrElse(createJoinWithoutHint())
+ } else {
+ val smallerSide = getSmallerSide(left, right)
+ val buildSide = if (canBuildLeft(joinType)) {
+ // For RIGHT JOIN, we may broadcast left side even if the hint asks us to broadcast
+ // the right side. This is for history reasons.
+ if (hintToBroadcastLeft(hint) || canBroadcast(left)) {
+ BuildLeft
+ } else if (hintToBroadcastRight(hint)) {
+ BuildRight
+ } else {
+ smallerSide
+ }
+ } else {
+ // For LEFT JOIN, we may broadcast right side even if the hint asks us to broadcast
+ // the left side. This is for history reasons.
+ if (hintToBroadcastRight(hint) || canBroadcast(right)) {
+ BuildRight
+ } else if (hintToBroadcastLeft(hint)) {
+ BuildLeft
+ } else {
+ smallerSide
+ }
+ }
+ Seq(joins.BroadcastNestedLoopJoinExec(
+ planLater(left), planLater(right), buildSide, joinType, condition))
+ }
+
+ // --- Cases where this strategy does not apply ---------------------------------------------
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d0be216e6a2f..79045e8a5aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedFunction}
import org.apache.spark.sql.internal.SQLConf
@@ -1045,7 +1045,7 @@ object functions {
*/
def broadcast[T](df: Dataset[T]): Dataset[T] = {
Dataset[T](df.sparkSession,
- ResolvedHint(df.logicalPlan, HintInfo(broadcast = true)))(df.exprEnc)
+ ResolvedHint(df.logicalPlan, HintInfo(strategy = Some(BROADCAST))))(df.exprEnc)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4d63390fb452..92157d8ad49e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.executor.DataReadMethod.DataReadMethod
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.catalyst.plans.logical.Join
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join}
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -951,7 +951,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
case Join(_, _, _, _, hint) => hint
}
assert(hint.size == 1)
- assert(hint(0).leftHint.get.broadcast)
+ assert(hint(0).leftHint.get.strategy.contains(BROADCAST))
assert(hint(0).rightHint.isEmpty)
// Clean-up
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
index 67f0f1a6fd23..9c2dc0c62b2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
@@ -17,8 +17,14 @@
package org.apache.spark.sql
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.log4j.{AppenderSkeleton, Level}
+import org.apache.log4j.spi.LoggingEvent
+
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -30,6 +36,41 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
lazy val df2 = df.selectExpr("id as b1", "id as b2")
lazy val df3 = df.selectExpr("id as c1", "id as c2")
+ class MockAppender extends AppenderSkeleton {
+ val loggingEvents = new ArrayBuffer[LoggingEvent]()
+
+ override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent)
+ override def close(): Unit = {}
+ override def requiresLayout(): Boolean = false
+ }
+
+ def msgNoHintRelationFound(relation: String, hint: String): String =
+ s"Count not find relation '$relation' for join strategy hint '$hint'."
+
+ def msgNoJoinForJoinHint(strategy: String): String =
+ s"A join hint (strategy=$strategy) is specified but it is not part of a join relation."
+
+ def msgJoinHintOverridden(strategy: String): String =
+ s"Join hint (strategy=$strategy) is overridden by another hint and will not take effect."
+
+ def verifyJoinHintWithWarnings(
+ df: => DataFrame,
+ expectedHints: Seq[JoinHint],
+ warnings: Seq[String]): Unit = {
+ val logAppender = new MockAppender()
+ withLogAppender(logAppender) {
+ verifyJoinHint(df, expectedHints)
+ }
+ val warningMessages = logAppender.loggingEvents
+ .filter(_.getLevel == Level.WARN)
+ .map(_.getRenderedMessage)
+ .filter(_.contains("hint"))
+ assert(warningMessages.size == warnings.size)
+ warnings.foreach { w =>
+ assert(warningMessages.contains(w))
+ }
+ }
+
def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = {
val optimized = df.queryExecution.optimizedPlan
val joinHints = optimized collect {
@@ -43,14 +84,14 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
verifyJoinHint(
df.hint("broadcast").join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
verifyJoinHint(
df.join(df.hint("broadcast"), "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
)
}
@@ -59,18 +100,18 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
df1.join(df2.hint("broadcast").join(df3, 'b1 === 'c1).hint("broadcast"), 'a1 === 'c1),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
verifyJoinHint(
df1.hint("broadcast").join(df2, 'a1 === 'b1).hint("broadcast").join(df3, 'a1 === 'c1),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
}
@@ -89,13 +130,13 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
|) b on a.a1 = b.b1
""".stripMargin),
JoinHint(
- Some(HintInfo(broadcast = true)),
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
}
@@ -112,9 +153,9 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
)
verifyJoinHint(
@@ -122,25 +163,25 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint.NONE ::
JoinHint(
- Some(HintInfo(broadcast = true)),
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
None,
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
)
verifyJoinHint(
df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
.join(df3, 'b1 === 'c1 && 'a1 < 10),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint.NONE :: Nil
)
@@ -151,7 +192,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
.join(df, 'b1 === 'id),
JoinHint.NONE ::
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint.NONE :: Nil
)
@@ -164,7 +205,7 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
verifyJoinHint(
df.hint("broadcast").except(dfSub).join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
None) ::
JoinHint.NONE :: Nil
)
@@ -172,31 +213,112 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
df.join(df.hint("broadcast").intersect(dfSub), "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint.NONE :: Nil
)
}
test("hint merge") {
- verifyJoinHint(
+ verifyJoinHintWithWarnings(
df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
- None) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ None) :: Nil,
+ Nil
)
- verifyJoinHint(
+ verifyJoinHintWithWarnings(
df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) :: Nil
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil,
+ Nil
+ )
+ verifyJoinHintWithWarnings(
+ df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id").hint("broadcast"),
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH))),
+ None) :: Nil,
+ msgJoinHintOverridden("merge") ::
+ msgNoJoinForJoinHint("broadcast") :: Nil
+ )
+ verifyJoinHintWithWarnings(
+ df.join(df.hint("broadcast").limit(2).hint("merge"), "id")
+ .hint("shuffle_hash")
+ .hint("shuffle_replicate_nl")
+ .join(df, "id"),
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))),
+ None) ::
+ JoinHint(
+ None,
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil,
+ msgJoinHintOverridden("broadcast") ::
+ msgJoinHintOverridden("shuffle_hash") :: Nil
)
}
+ test("hint merge - SQL") {
+ withTempView("a", "b", "c") {
+ df1.createOrReplaceTempView("a")
+ df2.createOrReplaceTempView("b")
+ df3.createOrReplaceTempView("c")
+ verifyJoinHintWithWarnings(
+ sql("select /*+ shuffle_hash merge(a, c) broadcast(a, b)*/ * from a, b, c " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint(
+ None,
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) ::
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE))),
+ Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil,
+ msgNoJoinForJoinHint("shuffle_hash") ::
+ msgJoinHintOverridden("broadcast") :: Nil
+ )
+ verifyJoinHintWithWarnings(
+ sql("select /*+ shuffle_hash(a, b) merge(b, d) broadcast(b)*/ * from a, b, c " +
+ "where a.a1 = b.b1 and b.b1 = c.c1"),
+ JoinHint.NONE ::
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH))),
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil,
+ msgNoHintRelationFound("d", "merge(b, d)") ::
+ msgJoinHintOverridden("broadcast") ::
+ msgJoinHintOverridden("merge") :: Nil
+ )
+ verifyJoinHintWithWarnings(
+ sql(
+ """
+ |select /*+ broadcast(a, c) merge(a, d)*/ * from a
+ |join (
+ | select /*+ shuffle_hash(c) shuffle_replicate_nl(b, c)*/ * from b
+ | join c on b.b1 = c.c1
+ |) as d
+ |on a.a2 = d.b2
+ """.stripMargin),
+ JoinHint(
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) ::
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))),
+ Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil,
+ msgNoHintRelationFound("c", "broadcast(a, c)") ::
+ msgJoinHintOverridden("merge") ::
+ msgJoinHintOverridden("shuffle_replicate_nl") :: Nil
+ )
+ }
+ }
+
test("nested hint") {
verifyJoinHint(
df.hint("broadcast").hint("broadcast").filter('id > 2).join(df, "id"),
JoinHint(
- Some(HintInfo(broadcast = true)),
+ Some(HintInfo(strategy = Some(BROADCAST))),
+ None) :: Nil
+ )
+ verifyJoinHint(
+ df.hint("shuffle_hash").hint("broadcast").hint("merge").filter('id > 2).join(df, "id"),
+ JoinHint(
+ Some(HintInfo(strategy = Some(SHUFFLE_MERGE))),
None) :: Nil
)
}
@@ -209,12 +331,230 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
join.join(broadcasted, "id").join(broadcasted, "id"),
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint(
None,
- Some(HintInfo(broadcast = true))) ::
+ Some(HintInfo(strategy = Some(BROADCAST)))) ::
JoinHint.NONE :: JoinHint.NONE :: JoinHint.NONE :: Nil
)
}
}
+
+ def equiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"): String =
+ hints.map("/*+ " + _ + " */").mkString(
+ "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key = t2.key")
+
+ def nonEquiJoinQueryWithHint(hints: Seq[String], joinType: String = "INNER"): String =
+ hints.map("/*+ " + _ + " */").mkString(
+ "SELECT ", " ", s" * FROM t1 $joinType JOIN t2 ON t1.key > t2.key")
+
+ private def assertBroadcastHashJoin(df: DataFrame, buildSide: BuildSide): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val broadcastHashJoins = executedPlan.collect {
+ case b: BroadcastHashJoinExec => b
+ }
+ assert(broadcastHashJoins.size == 1)
+ assert(broadcastHashJoins.head.buildSide == buildSide)
+ }
+
+ private def assertBroadcastNLJoin(df: DataFrame, buildSide: BuildSide): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val broadcastNLJoins = executedPlan.collect {
+ case b: BroadcastNestedLoopJoinExec => b
+ }
+ assert(broadcastNLJoins.size == 1)
+ assert(broadcastNLJoins.head.buildSide == buildSide)
+ }
+
+ private def assertShuffleHashJoin(df: DataFrame, buildSide: BuildSide): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleHashJoins = executedPlan.collect {
+ case s: ShuffledHashJoinExec => s
+ }
+ assert(shuffleHashJoins.size == 1)
+ assert(shuffleHashJoins.head.buildSide == buildSide)
+ }
+
+ private def assertShuffleMergeJoin(df: DataFrame): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleMergeJoins = executedPlan.collect {
+ case s: SortMergeJoinExec => s
+ }
+ assert(shuffleMergeJoins.size == 1)
+ }
+
+ private def assertShuffleReplicateNLJoin(df: DataFrame): Unit = {
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleReplicateNLJoins = executedPlan.collect {
+ case c: CartesianProductExec => c
+ }
+ assert(shuffleReplicateNLJoins.size == 1)
+ }
+
+ test("join strategy hint - broadcast") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
+
+ val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+ val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+ assert(t1Size < t2Size)
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ // Broadcast hint specified on one side
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil)), BuildLeft)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("BROADCAST(t2)" :: Nil)), BuildRight)
+
+ // Determine build side based on the join type and child relation sizes
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil)), BuildLeft)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "left")), BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("BROADCAST(t1, t2)" :: Nil, "right")), BuildLeft)
+
+ // Use broadcast-hash join if hinted "broadcast" and equi-join
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)), BuildRight)
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t1, t2)" :: Nil)), BuildLeft)
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil)),
+ BuildLeft)
+
+ // Use broadcast-nl join if hinted "broadcast" and non-equi-join
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "BROADCAST(t1)" :: Nil)), BuildLeft)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "BROADCAST(t2)" :: Nil)), BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: "BROADCAST(t2)" :: Nil)),
+ BuildRight)
+
+ // Broadcast hint specified but not doable
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: Nil, "left")))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: Nil, "right")))
+ }
+ }
+ }
+
+ test("join strategy hint - shuffle-merge") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) {
+ // Shuffle-merge hint specified on one side
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_MERGE(t1)" :: Nil)))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("MERGEJOIN(t2)" :: Nil)))
+
+ // Shuffle-merge hint specified on both sides
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("MERGE(t1, t2)" :: Nil)))
+
+ // Shuffle-merge hint prioritized over shuffle-hash hint and shuffle-replicate-nl hint
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "MERGE(t1)" :: Nil, "left")))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("MERGE(t2)" :: "SHUFFLE_HASH(t1)" :: Nil, "right")))
+
+ // Broadcast hint prioritized over shuffle-merge hint, but broadcast hint is not applicable
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "MERGE(t2)" :: Nil, "left")))
+ assertShuffleMergeJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "MERGE(t1)" :: Nil, "right")))
+
+ // Shuffle-merge hint specified but not doable
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("MERGE(t1, t2)" :: Nil, "left")), BuildRight)
+ }
+ }
+ }
+
+ test("join strategy hint - shuffle-hash") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
+
+ val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+ val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+ assert(t1Size < t2Size)
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) {
+ // Shuffle-hash hint specified on one side
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: Nil)), BuildRight)
+
+ // Determine build side based on the join type and child relation sizes
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil)), BuildLeft)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "left")), BuildRight)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1, t2)" :: Nil, "right")), BuildLeft)
+
+ // Shuffle-hash hint prioritized over shuffle-replicate-nl hint
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: "SHUFFLE_HASH(t1)" :: Nil)),
+ BuildLeft)
+
+ // Broadcast hint prioritized over shuffle-hash hint, but broadcast hint is not applicable
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t1)" :: "SHUFFLE_HASH(t2)" :: Nil, "left")),
+ BuildRight)
+ assertShuffleHashJoin(
+ sql(equiJoinQueryWithHint("BROADCAST(t2)" :: "SHUFFLE_HASH(t1)" :: Nil, "right")),
+ BuildLeft)
+
+ // Shuffle-hash hint specified but not doable
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil, "left")), BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t1)" :: Nil)), BuildLeft)
+ }
+ }
+ }
+
+ test("join strategy hint - shuffle-replicate-nl") {
+ withTempView("t1", "t2") {
+ Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1")
+ Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Int.MaxValue.toString) {
+ // Shuffle-replicate-nl hint specified on one side
+ assertShuffleReplicateNLJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1)" :: Nil)))
+ assertShuffleReplicateNLJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t2)" :: Nil)))
+
+ // Shuffle-replicate-nl hint specified on both sides
+ assertShuffleReplicateNLJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil)))
+
+ // Shuffle-merge hint prioritized over shuffle-replicate-nl hint, but shuffle-merge hint
+ // is not applicable
+ assertShuffleReplicateNLJoin(
+ sql(nonEquiJoinQueryWithHint("MERGE(t1)" :: "SHUFFLE_REPLICATE_NL(t2)" :: Nil)))
+
+ // Shuffle-hash hint prioritized over shuffle-replicate-nl hint, but shuffle-hash hint is
+ // not applicable
+ assertShuffleReplicateNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_HASH(t2)" :: "SHUFFLE_REPLICATE_NL(t1)" :: Nil)))
+
+ // Shuffle-replicate-nl hint specified but not doable
+ assertBroadcastHashJoin(
+ sql(equiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "left")), BuildRight)
+ assertBroadcastNLJoin(
+ sql(nonEquiJoinQueryWithHint("SHUFFLE_REPLICATE_NL(t1, t2)" :: Nil, "right")), BuildLeft)
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index f238148e61c3..05c583c80e50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
+import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -216,10 +217,10 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution
.optimizedPlan
- assert(plan1.asInstanceOf[Join].hint.leftHint.get.broadcast)
+ assert(plan1.asInstanceOf[Join].hint.leftHint.get.strategy.contains(BROADCAST))
assert(plan1.asInstanceOf[Join].hint.rightHint.isEmpty)
assert(plan2.asInstanceOf[Join].hint.leftHint.isEmpty)
- assert(plan2.asInstanceOf[Join].hint.rightHint.get.broadcast)
+ assert(plan2.asInstanceOf[Join].hint.rightHint.get.strategy.contains(BROADCAST))
assert(plan3.asInstanceOf[Join].hint.leftHint.isEmpty)
assert(plan3.asInstanceOf[Join].hint.rightHint.isEmpty)
}