diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala index c6e0c74527b8f..71c6d40dfba16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala @@ -24,15 +24,16 @@ import org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo} * The hint error handler that logs warnings for each hint error. */ object HintErrorLogger extends HintErrorHandler with Logging { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def hintNotRecognized(name: String, parameters: Seq[Any]): Unit = { logWarning(s"Unrecognized hint: ${hintToPrettyString(name, parameters)}") } override def hintRelationsNotFound( - name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit = { - invalidRelations.foreach { n => - logWarning(s"Count not find relation '$n' specified in hint " + + name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit = { + invalidRelations.foreach { ident => + logWarning(s"Count not find relation '${ident.quoted}' specified in hint " + s"'${hintToPrettyString(name, parameters)}'.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 5b77d67bd1340..81de086e78f91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -64,31 +64,59 @@ object ResolveHints { _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT)))) } + // This method checks if given multi-part identifiers are matched with each other. + // The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch + // in the analyzer and we cannot semantically compare them at this stage. + // Therefore, we follow a simple rule; they match if an identifier in a hint + // is a tail of an identifier in a relation. This process is independent of a session + // catalog (`currentDb` in [[SessionCatalog]]) and it just compares them literally. + // + // For example, + // * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`, + // the broadcast hint will match both tables, `db1.t` and `t`, + // even when the current db is `db2`. + // * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN t`, + // the broadcast hint will match the left-side table only, `default.t`. + private def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = { + if (identInHint.length <= identInQuery.length) { + identInHint.zip(identInQuery.takeRight(identInHint.length)) + .forall { case (i1, i2) => resolver(i1, i2) } + } else { + false + } + } + + private def extractIdentifier(r: SubqueryAlias): Seq[String] = { + r.identifier.qualifier :+ r.identifier.name + } + private def applyJoinStrategyHint( plan: LogicalPlan, - relations: mutable.HashSet[String], + relationsInHint: Set[Seq[String]], + relationsInHintWithMatch: mutable.HashSet[Seq[String]], hintName: String): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true + def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = { + relationsInHint.find(matchedIdentifier(_, identInQuery)) + .map(relationsInHintWithMatch.add).nonEmpty + } + val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { case ResolvedHint(u @ UnresolvedRelation(ident), hint) - if relations.exists(resolver(_, ident.last)) => - relations.remove(ident.last) + if matchedIdentifierInHint(ident) => ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler)) case ResolvedHint(r: SubqueryAlias, hint) - if relations.exists(resolver(_, r.alias)) => - relations.remove(r.alias) + if matchedIdentifierInHint(extractIdentifier(r)) => ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler)) - case u @ UnresolvedRelation(ident) if relations.exists(resolver(_, ident.last)) => - relations.remove(ident.last) + case UnresolvedRelation(ident) if matchedIdentifierInHint(ident) => ResolvedHint(plan, createHintInfo(hintName)) - case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) => - relations.remove(r.alias) + case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) => ResolvedHint(plan, createHintInfo(hintName)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => @@ -107,7 +135,9 @@ object ResolveHints { } if ((plan fastEquals newNode) && recurse) { - newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName)) + newNode.mapChildren { child => + applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName) + } } else { newNode } @@ -120,17 +150,19 @@ object ResolveHints { ResolvedHint(h.child, createHintInfo(h.name)) } else { // Otherwise, find within the subtree query plans to apply the hint. - val relationNames = h.parameters.map { - case tableName: String => tableName - case tableId: UnresolvedAttribute => tableId.name + val relationNamesInHint = h.parameters.map { + case tableName: String => UnresolvedAttribute.parseAttributeName(tableName) + case tableId: UnresolvedAttribute => tableId.nameParts case unsupported => throw new AnalysisException("Join strategy hint parameter " + s"should be an identifier or string but was $unsupported (${unsupported.getClass}") - } - val relationNameSet = new mutable.HashSet[String] - relationNames.foreach(relationNameSet.add) - - val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name) - hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, relationNameSet.toSet) + }.toSet + val relationsInHintWithMatch = new mutable.HashSet[Seq[String]] + val applied = applyJoinStrategyHint( + h.child, relationNamesInHint, relationsInHintWithMatch, h.name) + + // Filters unmatched relation identifiers in the hint + val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch + hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, unmatchedIdents) applied } } @@ -246,5 +278,4 @@ object ResolveHints { h.child } } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1b59056c4da66..8bf1f19844556 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -196,7 +196,7 @@ package object expressions { // For example, consider an example where "cat" is the catalog name, "db1" is the database // name, "a" is the table name and "b" is the column name and "c" is the struct field name. // If the name parts is cat.db1.a.b.c, then Attribute will match - // Attribute(b, qualifier("cat", "db1, "a")) and List("c") will be the second element + // Attribute(b, qualifier("cat", "db1", "a")) and List("c") will be the second element var matches: (Seq[Attribute], Seq[String]) = nameParts match { case catalogPart +: dbPart +: tblPart +: name +: nestedFields => val key = (catalogPart.toLowerCase(Locale.ROOT), dbPart.toLowerCase(Locale.ROOT), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index f26e5662ee856..a325b61fcc5a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -186,7 +186,8 @@ trait HintErrorHandler { * @param parameters the hint parameters * @param invalidRelations the set of relation names that cannot be associated */ - def hintRelationsNotFound(name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit + def hintRelationsNotFound( + name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit /** * Callback for a join hint specified on a relation that is not part of a join. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3f8d409992381..4473c20b2cca6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -45,6 +45,8 @@ trait AnalysisTest extends PlanTest { catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) + catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true) + catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases +: extendedAnalysisRules } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 5e66c038738a4..ca7d28401cf2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -241,4 +241,52 @@ class ResolveHintsSuite extends AnalysisTest { Project(testRelation.output, testRelation), caseSensitive = false) } + + test("Supports multi-part table names for broadcast hint resolution") { + // local temp table (single-part identifier case) + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("table", "table2"), + table("TaBlE").join(table("TaBlE2"))), + Join( + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), + ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))), + Inner, + None, + JoinHint.NONE), + caseSensitive = false) + + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"), + table("TaBlE").join(table("TaBlE2"))), + Join( + ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), + testRelation2, + Inner, + None, + JoinHint.NONE), + caseSensitive = true) + + // global temp table (multi-part identifier case) + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("GlOBal_TeMP.table4", "table5"), + table("global_temp", "table4").join(table("global_temp", "table5"))), + Join( + ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))), + ResolvedHint(testRelation5, HintInfo(strategy = Some(BROADCAST))), + Inner, + None, + JoinHint.NONE), + caseSensitive = false) + + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"), + table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), + Join( + ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))), + testRelation5, + Inner, + None, + JoinHint.NONE), + caseSensitive = true) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index e12e272aedffe..33b6029070938 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -44,6 +44,8 @@ object TestRelations { AttributeReference("g", StringType)(), AttributeReference("h", MapType(IntegerType, IntegerType))()) + val testRelation5 = LocalRelation(AttributeReference("i", StringType)()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c7545bcad8962..6b772e53ac184 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo, Join, JoinHint, LogicalPlan, Project} +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -322,4 +326,96 @@ class DataFrameJoinSuite extends QueryTest } } } + + test("Supports multi-part names for broadcast hint resolution") { + val (table1Name, table2Name) = ("t1", "t2") + + withTempDatabase { dbName => + withTable(table1Name, table2Name) { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + spark.range(50).write.saveAsTable(s"$dbName.$table1Name") + spark.range(100).write.saveAsTable(s"$dbName.$table2Name") + + def checkIfHintApplied(df: DataFrame): Unit = { + val sparkPlan = df.queryExecution.executedPlan + val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.size == 1) + val broadcastExchanges = broadcastHashJoins.head.collect { + case p: BroadcastExchangeExec => p + } + assert(broadcastExchanges.size == 1) + val tables = broadcastExchanges.head.collect { + case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent + } + assert(tables.size == 1) + assert(tables.head === TableIdentifier(table1Name, Some(dbName))) + } + + def checkIfHintNotApplied(df: DataFrame): Unit = { + val sparkPlan = df.queryExecution.executedPlan + val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.isEmpty) + } + + def sqlTemplate(tableName: String, hintTableName: String): DataFrame = { + sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " + + s"FROM $tableName, $dbName.$table2Name " + + s"WHERE $tableName.id = $table2Name.id") + } + + def dfTemplate(tableName: String, hintTableName: String): DataFrame = { + spark.table(tableName).join(spark.table(s"$dbName.$table2Name"), "id") + .hint("broadcast", hintTableName) + } + + sql(s"USE $dbName") + + checkIfHintApplied(sqlTemplate(table1Name, table1Name)) + checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name")) + checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", table1Name)) + checkIfHintNotApplied(sqlTemplate(table1Name, s"$dbName.$table1Name")) + + checkIfHintApplied(dfTemplate(table1Name, table1Name)) + checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name")) + checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", table1Name)) + checkIfHintApplied(dfTemplate(table1Name, s"$dbName.$table1Name")) + checkIfHintApplied(dfTemplate(table1Name, + s"${CatalogManager.SESSION_CATALOG_NAME}.$dbName.$table1Name")) + + withView("tv") { + sql(s"CREATE VIEW tv AS SELECT * FROM $dbName.$table1Name") + checkIfHintApplied(sqlTemplate("tv", "tv")) + checkIfHintNotApplied(sqlTemplate("tv", s"$dbName.tv")) + + checkIfHintApplied(dfTemplate("tv", "tv")) + checkIfHintApplied(dfTemplate("tv", s"$dbName.tv")) + } + } + } + } + } + + test("The same table name exists in two databases for broadcast hint resolution") { + val (db1Name, db2Name) = ("db1", "db2") + + withDatabase(db1Name, db2Name) { + withTable("t") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + sql(s"CREATE DATABASE $db1Name") + sql(s"CREATE DATABASE $db2Name") + spark.range(1).write.saveAsTable(s"$db1Name.t") + spark.range(1).write.saveAsTable(s"$db2Name.t") + + // Checks if a broadcast hint applied in both sides + val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " + + s"WHERE $db1Name.t.id = $db2Name.t.id" + sql(statement).queryExecution.optimizedPlan match { + case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))), + Some(HintInfo(Some(BROADCAST))))) => + case _ => fail("broadcast hint not found in both tables") + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 7fbfa73623c85..28e82aa14e0d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -170,4 +171,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession { isTemporary = true).toString) } } + + test("broadcast hint on global temp view") { + withGlobalTempView("v1") { + spark.range(10).createGlobalTempView("v1") + withTempView("v2") { + spark.range(10).createTempView("v2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", + "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" + ).foreach { statement => + sql(statement).queryExecution.optimizedPlan match { + case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))), None)) => + case _ => fail("broadcast hint not found in a left-side table") + } + } + } + } + } + } }