HintErrorLogger.scala
@@ -24,15 +24,16 @@ import org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo}
  * The hint error handler that logs warnings for each hint error.
  */
 object HintErrorLogger extends HintErrorHandler with Logging {
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
   override def hintNotRecognized(name: String, parameters: Seq[Any]): Unit = {
     logWarning(s"Unrecognized hint: ${hintToPrettyString(name, parameters)}")
   }
 
   override def hintRelationsNotFound(
-      name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit = {
-    invalidRelations.foreach { n =>
-      logWarning(s"Count not find relation '$n' specified in hint " +
+      name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit = {
+    invalidRelations.foreach { ident =>
+      logWarning(s"Could not find relation '${ident.quoted}' specified in hint " +
         s"'${hintToPrettyString(name, parameters)}'.")
     }
   }
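
Here `ident` is now a multi-part identifier (`Seq[String]`), and `quoted` comes from the newly imported CatalogV2Implicits. A minimal standalone sketch of that quoting behavior, assuming the usual rule of backtick-escaping any part that contains a dot or backtick (the exact escaping in Spark's MultipartIdentifierHelper may differ slightly):

// Minimal sketch approximating CatalogV2Implicits' `quoted` on a Seq[String].
object QuotingSketch {
  private def quoteIfNeeded(part: String): String =
    if (part.contains(".") || part.contains("`")) {
      s"`${part.replace("`", "``")}`" // escape embedded backticks by doubling them
    } else {
      part
    }

  def quoted(ident: Seq[String]): String = ident.map(quoteIfNeeded).mkString(".")

  def main(args: Array[String]): Unit = {
    println(quoted(Seq("db1", "t")))   // db1.t
    println(quoted(Seq("my.db", "t"))) // `my.db`.t
  }
}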

ResolveHints.scala
@@ -64,31 +64,59 @@ object ResolveHints {
         _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT))))
     }
 
+    // This method checks if the given multi-part identifiers match each other.
+    // The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch
+    // in the analyzer, so we cannot compare the identifiers semantically here.
+    // Therefore, we follow a simple rule: they match if an identifier in a hint
+    // is a tail of an identifier in a relation. This check is independent of the
+    // session catalog (`currentDb` in [[SessionCatalog]]) and compares literally.
+    //
+    // For example,
+    //  * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`,
+    //    the broadcast hint will match both tables, `db1.t` and `t`,
+    //    even when the current db is `db2`.
+    //  * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN t`,
+    //    the broadcast hint will match the left-side table only, `default.t`.
+    private def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = {
+      if (identInHint.length <= identInQuery.length) {
+        identInHint.zip(identInQuery.takeRight(identInHint.length))
+          .forall { case (i1, i2) => resolver(i1, i2) }
+      } else {
+        false
+      }
+    }
+
+    private def extractIdentifier(r: SubqueryAlias): Seq[String] = {
+      r.identifier.qualifier :+ r.identifier.name
+    }
+
     private def applyJoinStrategyHint(
         plan: LogicalPlan,
-        relations: mutable.HashSet[String],
+        relationsInHint: Set[Seq[String]],
+        relationsInHintWithMatch: mutable.HashSet[Seq[String]],
         hintName: String): LogicalPlan = {
       // Whether to continue recursing down the tree
       var recurse = true
 
+      def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = {
+        relationsInHint.find(matchedIdentifier(_, identInQuery))
+          .map(relationsInHintWithMatch.add).nonEmpty
+      }
+
       val newNode = CurrentOrigin.withOrigin(plan.origin) {
         plan match {
           case ResolvedHint(u @ UnresolvedRelation(ident), hint)
-              if relations.exists(resolver(_, ident.last)) =>
-            relations.remove(ident.last)
+              if matchedIdentifierInHint(ident) =>
             ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler))
 
           case ResolvedHint(r: SubqueryAlias, hint)
-              if relations.exists(resolver(_, r.alias)) =>
-            relations.remove(r.alias)
+              if matchedIdentifierInHint(extractIdentifier(r)) =>
             ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler))
 
-          case u @ UnresolvedRelation(ident) if relations.exists(resolver(_, ident.last)) =>
-            relations.remove(ident.last)
+          case UnresolvedRelation(ident) if matchedIdentifierInHint(ident) =>
             ResolvedHint(plan, createHintInfo(hintName))
 
-          case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
-            relations.remove(r.alias)
+          case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) =>
             ResolvedHint(plan, createHintInfo(hintName))
 
           case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
@@ -107,7 +135,9 @@ object ResolveHints {
       }
 
       if ((plan fastEquals newNode) && recurse) {
-        newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName))
+        newNode.mapChildren { child =>
+          applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName)
+        }
       } else {
         newNode
       }
@@ -120,17 +150,19 @@ object ResolveHints {
             ResolvedHint(h.child, createHintInfo(h.name))
           } else {
             // Otherwise, find within the subtree query plans to apply the hint.
-            val relationNames = h.parameters.map {
-              case tableName: String => tableName
-              case tableId: UnresolvedAttribute => tableId.name
+            val relationNamesInHint = h.parameters.map {
+              case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
+              case tableId: UnresolvedAttribute => tableId.nameParts
               case unsupported => throw new AnalysisException("Join strategy hint parameter " +
                 s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
-            }
-            val relationNameSet = new mutable.HashSet[String]
-            relationNames.foreach(relationNameSet.add)
-
-            val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
-            hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, relationNameSet.toSet)
+            }.toSet
+            val relationsInHintWithMatch = new mutable.HashSet[Seq[String]]
+            val applied = applyJoinStrategyHint(
+              h.child, relationNamesInHint, relationsInHintWithMatch, h.name)
+
+            // Report the hint identifiers that did not match any relation
+            val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch
+            hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, unmatchedIdents)
             applied
           }
         }
@@ -246,5 +278,4 @@ object ResolveHints {
             h.child
         }
       }
-
 }
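
The suffix-match rule documented above is easy to exercise in isolation. A minimal, self-contained sketch — re-implementing matchedIdentifier with a case-insensitive resolver and naive dot-splitting in place of UnresolvedAttribute.parseAttributeName (which additionally handles backticked parts) — reproduces the two examples from the comment:

// Standalone sketch of the suffix-match rule used for hint resolution.
// Assumptions: case-insensitive resolver, naive dot-splitting.
object HintMatchSketch {
  private val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)

  def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean =
    identInHint.length <= identInQuery.length &&
      identInHint.zip(identInQuery.takeRight(identInHint.length))
        .forall { case (i1, i2) => resolver(i1, i2) }

  def main(args: Array[String]): Unit = {
    def parse(s: String): Seq[String] = s.split('.').toSeq
    // SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t  -- hint matches both sides:
    println(matchedIdentifier(parse("t"), parse("db1.t")))             // true
    println(matchedIdentifier(parse("t"), parse("t")))                 // true
    // SELECT /*+ BROADCAST(default.t) */ ... -- hint matches only default.t:
    println(matchedIdentifier(parse("default.t"), parse("default.t"))) // true
    println(matchedIdentifier(parse("default.t"), parse("t")))         // false
  }
}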

package.scala (package object expressions)
@@ -196,7 +196,7 @@ package object expressions {
       // For example, consider an example where "cat" is the catalog name, "db1" is the database
       // name, "a" is the table name and "b" is the column name and "c" is the struct field name.
       // If the name parts is cat.db1.a.b.c, then Attribute will match
-      // Attribute(b, qualifier("cat", "db1, "a")) and List("c") will be the second element
+      // Attribute(b, qualifier("cat", "db1", "a")) and List("c") will be the second element
       var matches: (Seq[Attribute], Seq[String]) = nameParts match {
         case catalogPart +: dbPart +: tblPart +: name +: nestedFields =>
           val key = (catalogPart.toLowerCase(Locale.ROOT), dbPart.toLowerCase(Locale.ROOT),
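
As a quick illustration of the comment above, the `+:` pattern peels a four-or-more-part name into the qualifier key used for the attribute lookup plus trailing nested struct fields. A toy sketch with hypothetical values, not Spark code:

// Toy illustration of splitting name parts into (catalog, db, table, column)
// plus nested struct fields, mirroring the `+:` pattern above.
object NamePartsSketch {
  def main(args: Array[String]): Unit = {
    val nameParts = Seq("cat", "db1", "a", "b", "c")
    nameParts match {
      case catalogPart +: dbPart +: tblPart +: name +: nestedFields =>
        // key for the qualified-attribute map; "c" is left over as a struct field
        println((catalogPart, dbPart, tblPart, name)) // (cat,db1,a,b)
        println(nestedFields)                         // List(c)
      case _ =>
        println("fewer than four parts")
    }
  }
}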

hints.scala (trait HintErrorHandler)
@@ -186,7 +186,8 @@ trait HintErrorHandler {
    * @param parameters the hint parameters
    * @param invalidRelations the set of relation names that cannot be associated
    */
-  def hintRelationsNotFound(name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit
+  def hintRelationsNotFound(
+      name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit
 
   /**
    * Callback for a join hint specified on a relation that is not part of a join.

AnalysisTest.scala
@@ -45,6 +45,8 @@ trait AnalysisTest extends PlanTest {
     catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
     catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
     catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
+    catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true)
+    catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true)
     new Analyzer(catalog, conf) {
       override val extendedResolutionRules = EliminateSubqueryAliases +: extendedAnalysisRules
     }

ResolveHintsSuite.scala
@@ -241,4 +241,52 @@ class ResolveHintsSuite extends AnalysisTest {
       Project(testRelation.output, testRelation),
       caseSensitive = false)
   }
+
+  test("Supports multi-part table names for broadcast hint resolution") {
+    // local temp table (single-part identifier case)
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("table", "table2"),
+        table("TaBlE").join(table("TaBlE2"))),
+      Join(
+        ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+        ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = false)
+
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"),
+        table("TaBlE").join(table("TaBlE2"))),
+      Join(
+        ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+        testRelation2,
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = true)
+
+    // global temp table (multi-part identifier case)
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("GlOBal_TeMP.table4", "table5"),
+        table("global_temp", "table4").join(table("global_temp", "table5"))),
+      Join(
+        ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
+        ResolvedHint(testRelation5, HintInfo(strategy = Some(BROADCAST))),
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = false)
+
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"),
+        table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))),
+      Join(
+        ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
+        testRelation5,
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = true)
+  }
 }

TestRelations.scala
@@ -44,6 +44,8 @@ object TestRelations {
     AttributeReference("g", StringType)(),
     AttributeReference("h", MapType(IntegerType, IntegerType))())
 
+  val testRelation5 = LocalRelation(AttributeReference("i", StringType)())
+
   val nestedRelation = LocalRelation(
     AttributeReference("top", StructType(
       StructField("duplicateField", StringType) ::

DataFrameJoinSuite.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo, Join, JoinHint, LogicalPlan, Project}
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -322,4 +326,96 @@ class DataFrameJoinSuite extends QueryTest
       }
     }
   }
+
+  test("Supports multi-part names for broadcast hint resolution") {
+    val (table1Name, table2Name) = ("t1", "t2")
+
+    withTempDatabase { dbName =>
+      withTable(table1Name, table2Name) {
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
+          spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
+
+          def checkIfHintApplied(df: DataFrame): Unit = {
+            val sparkPlan = df.queryExecution.executedPlan
+            val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
+            assert(broadcastHashJoins.size == 1)
+            val broadcastExchanges = broadcastHashJoins.head.collect {
+              case p: BroadcastExchangeExec => p
+            }
+            assert(broadcastExchanges.size == 1)
+            val tables = broadcastExchanges.head.collect {
+              case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
+            }
+            assert(tables.size == 1)
+            assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
+          }
+
+          def checkIfHintNotApplied(df: DataFrame): Unit = {
+            val sparkPlan = df.queryExecution.executedPlan
+            val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
+            assert(broadcastHashJoins.isEmpty)
+          }
+
+          def sqlTemplate(tableName: String, hintTableName: String): DataFrame = {
+            sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
+              s"FROM $tableName, $dbName.$table2Name " +
+              s"WHERE $tableName.id = $table2Name.id")
+          }
+
+          def dfTemplate(tableName: String, hintTableName: String): DataFrame = {
+            spark.table(tableName).join(spark.table(s"$dbName.$table2Name"), "id")
+              .hint("broadcast", hintTableName)
+          }
+
+          sql(s"USE $dbName")
+
+          checkIfHintApplied(sqlTemplate(table1Name, table1Name))
+          checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name"))
+          checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", table1Name))
+          checkIfHintNotApplied(sqlTemplate(table1Name, s"$dbName.$table1Name"))
+
+          checkIfHintApplied(dfTemplate(table1Name, table1Name))
+          checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name"))
+          checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", table1Name))
+          checkIfHintApplied(dfTemplate(table1Name, s"$dbName.$table1Name"))
+          checkIfHintApplied(dfTemplate(table1Name,
+            s"${CatalogManager.SESSION_CATALOG_NAME}.$dbName.$table1Name"))
+
+          withView("tv") {
+            sql(s"CREATE VIEW tv AS SELECT * FROM $dbName.$table1Name")
+            checkIfHintApplied(sqlTemplate("tv", "tv"))
+            checkIfHintNotApplied(sqlTemplate("tv", s"$dbName.tv"))
+
+            checkIfHintApplied(dfTemplate("tv", "tv"))
+            checkIfHintApplied(dfTemplate("tv", s"$dbName.tv"))
+          }
+        }
+      }
+    }
+  }
+
+  test("The same table name exists in two databases for broadcast hint resolution") {
+    val (db1Name, db2Name) = ("db1", "db2")
+
+    withDatabase(db1Name, db2Name) {
+      withTable("t") {
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          sql(s"CREATE DATABASE $db1Name")
+          sql(s"CREATE DATABASE $db2Name")
+          spark.range(1).write.saveAsTable(s"$db1Name.t")
+          spark.range(1).write.saveAsTable(s"$db2Name.t")
+
+          // Check that the broadcast hint is applied to both sides
+          val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
+            s"WHERE $db1Name.t.id = $db2Name.t.id"
+          sql(statement).queryExecution.optimizedPlan match {
+            case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))),
+                Some(HintInfo(Some(BROADCAST))))) =>
+            case _ => fail("broadcast hint not applied to both tables")
+          }
+        }
+      }
+    }
+  }
 }

GlobalTempViewSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalog.Table
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -170,4 +171,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession {
         isTemporary = true).toString)
     }
   }
+
+  test("broadcast hint on global temp view") {
+    withGlobalTempView("v1") {
+      spark.range(10).createGlobalTempView("v1")
+      withTempView("v2") {
+        spark.range(10).createTempView("v2")
+
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          Seq(
+            "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id",
+            "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id"
+          ).foreach { statement =>
+            sql(statement).queryExecution.optimizedPlan match {
+              case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))), None)) =>
+              case _ => fail("broadcast hint not applied to the left-side table")
+            }
+          }
+        }
+      }
+    }
+  }
 }
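
Taken together: a hinted identifier now matches a relation when it is a suffix of the relation's multi-part name, independent of the current database. A hedged end-to-end usage sketch (assumes a local SparkSession can be created; database and table names here are illustrative only, not part of the change):

// Usage sketch: qualified names in join-strategy hints after this change.
import org.apache.spark.sql.SparkSession

object MultiPartHintExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    import spark.sql
    sql("CREATE DATABASE IF NOT EXISTS db1")
    sql("CREATE TABLE IF NOT EXISTS db1.t USING parquet AS SELECT id FROM range(10)")
    sql("CREATE TABLE IF NOT EXISTS default.t USING parquet AS SELECT id FROM range(10)")
    // An unqualified `t` in the hint suffix-matches both db1.t and default.t:
    sql("SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN default.t ON db1.t.id = default.t.id")
      .explain()
    // A qualified `db1.t` matches only the left side; an identifier matching no
    // relation would instead trigger HintErrorLogger's "Could not find relation" warning.
    sql("SELECT /*+ BROADCAST(db1.t) */ * FROM db1.t JOIN default.t ON db1.t.id = default.t.id")
      .explain()
  }
}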