Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -62,23 +64,27 @@ object ResolveHints {

private def applyJoinStrategyHint(
plan: LogicalPlan,
relations: Set[String],
relations: mutable.HashSet[String],
hintName: String): LogicalPlan = {
// Whether to continue recursing down the tree
var recurse = true

val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case ResolvedHint(u: UnresolvedRelation, _)
case ResolvedHint(u: UnresolvedRelation, hint)
if relations.exists(resolver(_, u.tableIdentifier.table)) =>
ResolvedHint(u, createHintInfo(hintName))
case ResolvedHint(r: SubqueryAlias, _)
relations.remove(u.tableIdentifier.table)
ResolvedHint(u, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))
case ResolvedHint(r: SubqueryAlias, hint)
if relations.exists(resolver(_, r.alias)) =>
ResolvedHint(r, createHintInfo(hintName))
relations.remove(r.alias)
ResolvedHint(r, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))

case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) =>
relations.remove(u.tableIdentifier.table)
ResolvedHint(plan, createHintInfo(hintName))
case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
relations.remove(r.alias)
ResolvedHint(plan, createHintInfo(hintName))

case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
Expand All @@ -103,19 +109,32 @@ object ResolveHints {
}
}

private def handleOverriddenHintInfo(hint: HintInfo): Unit = {
logWarning(s"Join hint $hint is overridden by another hint and will not take effect.")
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
if (h.parameters.isEmpty) {
// If there is no table alias specified, apply the hint on the entire subtree.
ResolvedHint(h.child, createHintInfo(h.name))
} else {
// Otherwise, find within the subtree query plans to apply the hint.
applyJoinStrategyHint(h.child, h.parameters.map {
val relationNames = h.parameters.map {
case tableName: String => tableName
case tableId: UnresolvedAttribute => tableId.name
case unsupported => throw new AnalysisException("Join strategy hint parameter " +
s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
}.toSet, h.name)
}
val relationNameSet = new mutable.HashSet[String]
relationNames.foreach(relationNameSet.add)

val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
relationNameSet.foreach { n =>
logWarning(s"Count not find relation '$n' for join strategy hint " +
s"'${h.name}${relationNames.mkString("(", ", ", ")")}'.")
}
applied
}
}
}
Expand Down Expand Up @@ -152,7 +171,9 @@ object ResolveHints {
*/
object RemoveAllHints extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case h: UnresolvedHint => h.child
case h: UnresolvedHint =>
logWarning(s"Unrecognized hint: ${h.name}${h.parameters.mkString("(", ", ", ")")}")
h.child
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,58 @@ object EliminateResolvedHint extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val pulledUp = plan transformUp {
case j: Join =>
val leftHint = mergeHints(collectHints(j.left))
val rightHint = mergeHints(collectHints(j.right))
j.copy(hint = JoinHint(leftHint, rightHint))
val (newLeft, leftHints) = extractHintsFromPlan(j.left)
val (newRight, rightHints) = extractHintsFromPlan(j.right)
val newJoinHint = JoinHint(mergeHints(leftHints), mergeHints(rightHints))
j.copy(left = newLeft, right = newRight, hint = newJoinHint)
}
pulledUp.transformUp {
case h: ResolvedHint => h.child
case h: ResolvedHint =>
handleInvalidHintInfo(h.hints)
h.child
}
}

/**
* Combine a list of [[HintInfo]]s into one [[HintInfo]].
*/
private def mergeHints(hints: Seq[HintInfo]): Option[HintInfo] = {
hints.reduceOption((h1, h2) => h1.merge(h2))
hints.reduceOption((h1, h2) => h1.merge(h2, handleOverriddenHintInfo))
}

private def collectHints(plan: LogicalPlan): Seq[HintInfo] = {
/**
* Extract all hints from the plan, returning a list of extracted hints and the transformed plan
* with [[ResolvedHint]] nodes removed. The returned hint list comes in top-down order.
* Note that hints can only be extracted from under certain nodes. Those that cannot be extracted
* in this method will be cleaned up later by this rule, and may emit warnings depending on the
* configurations.
*/
private def extractHintsFromPlan(plan: LogicalPlan): (LogicalPlan, Seq[HintInfo]) = {
plan match {
case h: ResolvedHint => h.hints +: collectHints(h.child)
case u: UnaryNode => collectHints(u.child)
case h: ResolvedHint =>
val (plan, hints) = extractHintsFromPlan(h.child)
(plan, h.hints +: hints)
case u: UnaryNode =>
val (plan, hints) = extractHintsFromPlan(u.child)
(u.withNewChildren(Seq(plan)), hints)
// TODO revisit this logic:
// except and intersect are semi/anti-joins which won't return more data then
// their left argument, so the join strategy hint should be propagated here
case i: Intersect => collectHints(i.left)
case e: Except => collectHints(e.left)
case _ => Seq.empty
// their left argument, so the broadcast hint should be propagated here
case i: Intersect =>
val (plan, hints) = extractHintsFromPlan(i.left)
(i.copy(left = plan), hints)
case e: Except =>
val (plan, hints) = extractHintsFromPlan(e.left)
(e.copy(left = plan), hints)
case p: LogicalPlan => (p, Seq.empty)
}
}

private def handleInvalidHintInfo(hint: HintInfo): Unit = {
logWarning(s"A join hint $hint is specified but it is not part of a join relation.")
}

private def handleOverriddenHintInfo(hint: HintInfo): Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a little weird to see this method being defined twice. Can we just log the message inside HintInfo.merge?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking to have a centralized handler for all kinds of hint events/errors, and the action, whether to log warnings/errors or to throw exceptions, can be configurable. WDYT?

logWarning(s"Join hint $hint is overridden by another hint and will not take effect.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,26 @@ object JoinHint {
case class HintInfo(strategy: Option[JoinStrategyHint] = None) {

/**
* Combine two [[HintInfo]]s into one [[HintInfo]], in which the new strategy will the strategy
* in this [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]].
* Combine this [[HintInfo]] with another [[HintInfo]] and return the new [[HintInfo]].
* @param other the other [[HintInfo]]
* @param hintOverriddenCallback a callback to notify if any [[HintInfo]] has been overridden
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we create a hint merging strategy framework, I think it will not be an arbitrary callback. Shall we make it simple now and leave it for future design? Then we can just log message inside this method.

* in this merge.
*
* Currently, for join strategy hints, the new [[HintInfo]] will contain the strategy in this
* [[HintInfo]] if defined, otherwise the strategy in the other [[HintInfo]]. The
* `hintOverriddenCallback` will be called if this [[HintInfo]] and the other [[HintInfo]]
* both have a strategy defined but the join strategies are different.
*/
def merge(other: HintInfo): HintInfo = {
def merge(other: HintInfo, hintOverriddenCallback: HintInfo => Unit): HintInfo = {
if (this.strategy.isDefined &&
other.strategy.isDefined &&
this.strategy.get != other.strategy.get) {
hintOverriddenCallback(other)
}
HintInfo(strategy = this.strategy.orElse(other.strategy))
}

override def toString: String = {
val hints = scala.collection.mutable.ArrayBuffer.empty[String]
if (strategy.isDefined) {
hints += s"strategy=${strategy.get}"
}

if (hints.isEmpty) "none" else hints.mkString("(", ", ", ")")
}
override def toString: String = if (strategy.isDefined) s"(strategy=${strategy.get})" else "none"
}

sealed abstract class JoinStrategyHint {
Expand Down Expand Up @@ -117,7 +122,7 @@ object JoinStrategyHint {
* equi-join keys.
*/
case object BROADCAST extends JoinStrategyHint {
override def displayName: String = "broadcast-hash"
override def displayName: String = "broadcast"
override def hintAliases: Set[String] = Set(
"BROADCAST",
"BROADCASTJOIN",
Expand All @@ -128,7 +133,7 @@ case object BROADCAST extends JoinStrategyHint {
* The hint for shuffle sort merge join.
*/
case object SHUFFLE_MERGE extends JoinStrategyHint {
override def displayName: String = "shuffle-merge"
override def displayName: String = "merge"
override def hintAliases: Set[String] = Set(
"SHUFFLE_MERGE",
"MERGE",
Expand All @@ -139,7 +144,7 @@ case object SHUFFLE_MERGE extends JoinStrategyHint {
* The hint for shuffle hash join.
*/
case object SHUFFLE_HASH extends JoinStrategyHint {
override def displayName: String = "shuffle-hash"
override def displayName: String = "shuffle_hash"
override def hintAliases: Set[String] = Set(
"SHUFFLE_HASH")
}
Expand All @@ -148,7 +153,7 @@ case object SHUFFLE_HASH extends JoinStrategyHint {
* The hint for shuffle-and-replicate nested loop join, a.k.a. cartesian product join.
*/
case object SHUFFLE_REPLICATE_NL extends JoinStrategyHint {
override def displayName: String = "shuffle-replicate-nested-loop"
override def displayName: String = "shuffle_replicate_nl"
override def hintAliases: Set[String] = Set(
"SHUFFLE_REPLICATE_NL")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hint for cartesian products is useful for users?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. In the default logic, broadcast-nl is prioritized over shuffle-replicate-nl (cartesian-product), so this can be used for special cases where shuffle-replicate-nl is favored.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might need a code comment to explain SHUFFLE_REPLICATE_NL is cartesian products.

}
86 changes: 70 additions & 16 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package org.apache.spark.sql

import scala.collection.mutable.ArrayBuffer

import org.apache.log4j.{AppenderSkeleton, Level}
import org.apache.log4j.spi.LoggingEvent

import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.joins._
Expand All @@ -31,6 +36,41 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
lazy val df2 = df.selectExpr("id as b1", "id as b2")
lazy val df3 = df.selectExpr("id as c1", "id as c2")

class MockAppender extends AppenderSkeleton {
val loggingEvents = new ArrayBuffer[LoggingEvent]()

override def append(loggingEvent: LoggingEvent): Unit = loggingEvents.append(loggingEvent)
override def close(): Unit = {}
override def requiresLayout(): Boolean = false
}

def msgNoHintRelationFound(relation: String, hint: String): String =
s"Count not find relation '$relation' for join strategy hint '$hint'."

def msgNoJoinForJoinHint(strategy: String): String =
s"A join hint (strategy=$strategy) is specified but it is not part of a join relation."

def msgJoinHintOverridden(strategy: String): String =
s"Join hint (strategy=$strategy) is overridden by another hint and will not take effect."

def verifyJoinHintWithWarnings(
df: => DataFrame,
expectedHints: Seq[JoinHint],
warnings: Seq[String]): Unit = {
val logAppender = new MockAppender()
withLogAppender(logAppender) {
verifyJoinHint(df, expectedHints)
}
val warningMessages = logAppender.loggingEvents
.filter(_.getLevel == Level.WARN)
.map(_.getRenderedMessage)
.filter(_.contains("hint"))
assert(warningMessages.size == warnings.size)
warnings.foreach { w =>
assert(warningMessages.contains(w))
}
}

def verifyJoinHint(df: DataFrame, expectedHints: Seq[JoinHint]): Unit = {
val optimized = df.queryExecution.optimizedPlan
val joinHints = optimized collect {
Expand Down Expand Up @@ -179,25 +219,29 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
}

test("hint merge") {
verifyJoinHint(
verifyJoinHintWithWarnings(
df.hint("broadcast").filter('id > 2).hint("broadcast").join(df, "id"),
JoinHint(
Some(HintInfo(strategy = Some(BROADCAST))),
None) :: Nil
None) :: Nil,
Nil
)
verifyJoinHint(
verifyJoinHintWithWarnings(
df.join(df.hint("broadcast").limit(2).hint("broadcast"), "id"),
JoinHint(
None,
Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil,
Nil
)
verifyJoinHint(
df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id"),
verifyJoinHintWithWarnings(
df.hint("merge").filter('id > 2).hint("shuffle_hash").join(df, "id").hint("broadcast"),
JoinHint(
Some(HintInfo(strategy = Some(SHUFFLE_HASH))),
None) :: Nil
None) :: Nil,
msgJoinHintOverridden("merge") ::
msgNoJoinForJoinHint("broadcast") :: Nil
)
verifyJoinHint(
verifyJoinHintWithWarnings(
df.join(df.hint("broadcast").limit(2).hint("merge"), "id")
.hint("shuffle_hash")
.hint("shuffle_replicate_nl")
Expand All @@ -207,7 +251,9 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
None) ::
JoinHint(
None,
Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil
Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) :: Nil,
msgJoinHintOverridden("broadcast") ::
msgJoinHintOverridden("shuffle_hash") :: Nil
)
}

Expand All @@ -216,25 +262,30 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
df1.createOrReplaceTempView("a")
df2.createOrReplaceTempView("b")
df3.createOrReplaceTempView("c")
verifyJoinHint(
sql("select /*+ merge(a, c) broadcast(a, b)*/ * from a, b, c " +
verifyJoinHintWithWarnings(
sql("select /*+ shuffle_hash merge(a, c) broadcast(a, b)*/ * from a, b, c " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) ::
JoinHint(
Some(HintInfo(strategy = Some(SHUFFLE_MERGE))),
Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil
Some(HintInfo(strategy = Some(BROADCAST)))) :: Nil,
msgNoJoinForJoinHint("shuffle_hash") ::
msgJoinHintOverridden("broadcast") :: Nil
)
verifyJoinHint(
verifyJoinHintWithWarnings(
sql("select /*+ shuffle_hash(a, b) merge(b, d) broadcast(b)*/ * from a, b, c " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint.NONE ::
JoinHint(
Some(HintInfo(strategy = Some(SHUFFLE_HASH))),
Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil
Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil,
msgNoHintRelationFound("d", "merge(b, d)") ::
msgJoinHintOverridden("broadcast") ::
msgJoinHintOverridden("merge") :: Nil
)
verifyJoinHint(
verifyJoinHintWithWarnings(
sql(
"""
|select /*+ broadcast(a, c) merge(a, d)*/ * from a
Expand All @@ -249,7 +300,10 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
Some(HintInfo(strategy = Some(SHUFFLE_MERGE)))) ::
JoinHint(
Some(HintInfo(strategy = Some(SHUFFLE_REPLICATE_NL))),
Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil
Some(HintInfo(strategy = Some(SHUFFLE_HASH)))) :: Nil,
msgNoHintRelationFound("c", "broadcast(a, c)") ::
msgJoinHintOverridden("merge") ::
msgJoinHintOverridden("shuffle_replicate_nl") :: Nil
)
}
}
Expand Down