Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class Analyzer(

lazy val batches: Seq[Batch] = Seq(
Batch("Hints", fixedPoint,
new ResolveHints.ResolveBroadcastHints(conf),
new ResolveHints.ResolveJoinStrategyHints(conf),
ResolveHints.ResolveCoalesceHints,
ResolveHints.RemoveAllHints),
Batch("Simple Sanity Check", Once,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,56 @@ import org.apache.spark.sql.internal.SQLConf


/**
* Collection of rules related to hints. The only hint currently available is broadcast join hint.
* Collection of rules related to hints. The only hint currently available is join strategy hint.
*
* Note that this is separated into two rules because in the future we might introduce new hint
* rules that have different ordering requirements from broadcast.
* rules that have different ordering requirements from join strategies.
*/
object ResolveHints {

/**
* For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of
* relation aliases can be specified in the hint. A broadcast hint plan node will be inserted
* on top of any relation (that is not aliased differently), subquery, or common table expression
* that match the specified name.
* The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a
* sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)",
* "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that
* is not aliased differently), subquery, or common table expression that matches the specified
* name.
*
* The hint resolution works by recursively traversing down the query plan to find a relation or
* subquery that matches one of the specified broadcast aliases. The traversal does not go past
* beyond any existing broadcast hints, subquery aliases.
* subquery that matches one of the specified relation aliases. The traversal does not go
* beyond any view reference, with clause or subquery alias.
*
* This rule must happen before common table expressions.
*/
class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] {
private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN")
class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] {
private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases)

def resolver: Resolver = conf.resolver

private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = {
/**
 * Builds the [[HintInfo]] for the given hint name.
 *
 * The name is compared case-insensitively against the aliases of every strategy listed in
 * [[JoinStrategyHint.strategies]]. The resulting strategy is left undefined when no strategy
 * declares a matching alias.
 *
 * @param hintName the hint name as written by the user
 * @return a [[HintInfo]] carrying the matching strategy, if any
 */
private def createHintInfo(hintName: String): HintInfo = {
  // Normalize the requested name once instead of once per candidate strategy.
  val wanted = hintName.toUpperCase(Locale.ROOT)
  val matched = JoinStrategyHint.strategies.find { candidate =>
    candidate.hintAliases.exists(_.toUpperCase(Locale.ROOT) == wanted)
  }
  HintInfo(strategy = matched)
}

private def applyJoinStrategyHint(
plan: LogicalPlan,
relations: Set[String],
hintName: String): LogicalPlan = {
// Whether to continue recursing down the tree
var recurse = true

val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) =>
ResolvedHint(plan, HintInfo(broadcast = true))
case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) =>
ResolvedHint(plan, HintInfo(broadcast = true))
case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) =>
ResolvedHint(plan, createHintInfo(hintName))
case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
ResolvedHint(plan, createHintInfo(hintName))

case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
// Don't traverse down these nodes.
// For an existing broadcast hint, there is no point going down (if we do, we either
// won't change the structure, or will introduce another broadcast hint that is useless.
// For an existing strategy hint, there is no point going down (if we do, we either
// won't change the structure, or will introduce another strategy hint that is useless).
// The rest (view, with, subquery) indicates different scopes that we shouldn't traverse
// down. Note that technically when this rule is executed, we haven't completed view
// resolution yet and as a result the view part should be deadcode. I'm leaving it here
Expand All @@ -80,25 +91,25 @@ object ResolveHints {
}

if ((plan fastEquals newNode) && recurse) {
newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast))
newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName))
} else {
newNode
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
if (h.parameters.isEmpty) {
// If there is no table alias specified, turn the entire subtree into a BroadcastHint.
ResolvedHint(h.child, HintInfo(broadcast = true))
// If there is no table alias specified, apply the hint on the entire subtree.
ResolvedHint(h.child, createHintInfo(h.name))
} else {
// Otherwise, find within the subtree query plans that should be broadcasted.
applyBroadcastHint(h.child, h.parameters.map {
// Otherwise, find within the subtree query plans to apply the hint.
applyJoinStrategyHint(h.child, h.parameters.map {
case tableName: String => tableName
case tableId: UnresolvedAttribute => tableId.name
case unsupported => throw new AnalysisException("Broadcast hint parameter should be " +
s"an identifier or string but was $unsupported (${unsupported.getClass}")
}.toSet)
case unsupported => throw new AnalysisException("Join strategy hint parameter " +
s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
}.toSet, h.name)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ object CatalogTable {
/**
* This class of statistics is used in [[CatalogTable]] to interact with metastore.
* We define this new class instead of directly using [[Statistics]] here because there are no
* concepts of attributes or broadcast hint in catalog.
* concepts of attributes in catalog.
*/
case class CatalogStatistics(
sizeInBytes: BigInt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ object EliminateResolvedHint extends Rule[LogicalPlan] {
}

private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = {
hints.reduceOption((h1, h2) => HintInfo(
broadcast = h1.broadcast || h2.broadcast))
hints.reduceOption((h1, h2) => h1.merge(h2))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks suspicious. What is the conflict resolution policy we are following here? A behavior change? Do we need to log the inputs and the resulting hint?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no behavior change here. There cannot be, coz we only had one hint before this PR, so a broadcast + broadcast = broadcast. The new behavior is defined in the description of this PR:

Conflicts within either side of the join: take the first strategy hint specified in the query, or the top hint node in Dataset. For example, in "select /*+ merge(t1) */ /*+ broadcast(t1) */ k1, v2 from t1 join t2 on t1.k1 = t2.k2", take "merge(t1)"; in df1.hint("merge").hint("shuffle_hash").join(df2), take "shuffle_hash". This is a general hint conflict resolving strategy, not specific to join strategy hint.

The merge function is implemented in HintInfo, which should be responsible for deciding the strategy of merging hints between a node on top and a node on the bottom.

}

private def collectHints(plan: LogicalPlan): Seq[HintInfo] = {
Expand All @@ -50,7 +49,7 @@ object EliminateResolvedHint extends Rule[LogicalPlan] {
case u: UnaryNode => collectHints(u.child)
// TODO revisit this logic:
// except and intersect are semi/anti-joins which won't return more data than
// their left argument, so the broadcast hint should be propagated here
// their left argument, so the join strategy hint should be propagated here
case i: Intersect => collectHints(i.left)
case e: Except => collectHints(e.left)
case _ => Seq.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,76 @@ object JoinHint {
/**
* The hint attributes to be applied on a specific node.
*
* @param broadcast If set to true, it indicates that the broadcast hash join is the preferred join
* strategy and the node with this hint is preferred to be the build side.
* @param strategy The preferred join strategy.
*/
case class HintInfo(broadcast: Boolean = false) {
case class HintInfo(strategy: Option[JoinStrategyHint] = None) {

/**
 * Merges this [[HintInfo]] with `other` into a new [[HintInfo]].
 *
 * The strategy of this [[HintInfo]] is kept when it is defined; the strategy of `other` is
 * used only as a fallback.
 *
 * @param other the [[HintInfo]] whose strategy is used when this one has none
 * @return the combined [[HintInfo]]
 */
def merge(other: HintInfo): HintInfo = {
  val preferred = strategy.orElse(other.strategy)
  HintInfo(strategy = preferred)
}

override def toString: String = {
val hints = scala.collection.mutable.ArrayBuffer.empty[String]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we don't need to create an array buffer here.

if (broadcast) {
hints += "broadcast"
if (strategy.isDefined) {
hints += s"strategy=${strategy.get}"
}

if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")")
}
}

/**
 * Base type of the join strategy hints. The class is sealed, so every strategy is declared
 * in this file.
 */
sealed abstract class JoinStrategyHint {

  /** The hint names that refer to this strategy. */
  def hintAliases: Set[String]

  /** The name used when this strategy is rendered as a string. */
  def displayName: String

  override def toString: String = displayName
}

/**
 * The set of available join strategy hints.
 *
 * A hinted strategy is used for the join it is associated with if doable. When the two sides
 * of a join carry contradicting strategy hints, the hints are prioritized in this order:
 * BROADCAST, then SHUFFLE_MERGE, then SHUFFLE_HASH, then SHUFFLE_REPLICATE_NL.
 */
object JoinStrategyHint {

  /** Every supported join strategy hint. */
  val strategies: Set[JoinStrategyHint] =
    Set(BROADCAST, SHUFFLE_MERGE, SHUFFLE_HASH, SHUFFLE_REPLICATE_NL)
}

/** Join strategy hint displayed as "broadcast-hash". */
case object BROADCAST extends JoinStrategyHint {

  // Hint names accepted for this strategy.
  override def hintAliases: Set[String] = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN")

  override def displayName: String = "broadcast-hash"
}

/** Join strategy hint displayed as "shuffle-merge". */
case object SHUFFLE_MERGE extends JoinStrategyHint {

  // Hint names accepted for this strategy.
  override def hintAliases: Set[String] = Set("SHUFFLE_MERGE", "MERGE", "MERGEJOIN")

  override def displayName: String = "shuffle-merge"
}

/** Join strategy hint displayed as "shuffle-hash". */
case object SHUFFLE_HASH extends JoinStrategyHint {

  // Hint names accepted for this strategy.
  override def hintAliases: Set[String] = Set("SHUFFLE_HASH")

  override def displayName: String = "shuffle-hash"
}

/**
 * Join strategy hint for the shuffle-and-replicate nested loop join, i.e. the cartesian
 * product join.
 */
case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint {
override def displayName: String = "shuffle-replicate-nested-loop"
override def hintAliases: Set[String] = Set(
"SHUFFLE_REPLICATE_NL")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hint for cartesian products is useful for users?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. In the default logic, broadcast-nl is prioritized over shuffle-replicate-nl (cartesian-product), so this can be used for special cases where shuffle-replicate-nl is favored.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might need a code comment to explain SHUFFLE_REPLICATE_NL is cartesian products.

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ class ResolveHintsSuite extends AnalysisTest {
test("case-sensitive or insensitive parameters") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)

checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")),
ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)

checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")),
ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = true)

checkAnalysis(
Expand All @@ -59,28 +59,29 @@ class ResolveHintsSuite extends AnalysisTest {
test("multiple broadcast hint aliases") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
Join(ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None, JoinHint.NONE),
Join(ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
Inner, None, JoinHint.NONE),
caseSensitive = false)
}

test("do not traverse past existing broadcast hints") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table"),
ResolvedHint(table("table").where('a > 1), HintInfo(broadcast = true))),
ResolvedHint(testRelation.where('a > 1), HintInfo(broadcast = true)).analyze,
ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))),
ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze,
caseSensitive = false)
}

test("should work for subqueries") {
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)

checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
ResolvedHint(testRelation, HintInfo(broadcast = true)),
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
caseSensitive = false)

// Negative case: if the alias doesn't match, don't match the original table name.
Expand All @@ -105,7 +106,7 @@ class ResolveHintsSuite extends AnalysisTest {
|SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable
""".stripMargin
),
ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(broadcast = true))
ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(strategy = Some(BROADCAST)))
.select('a).analyze,
caseSensitive = false)
}
Expand Down
Loading