diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index ca7d28401cf2a..d3bd5d07a0932 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -242,51 +242,56 @@ class ResolveHintsSuite extends AnalysisTest { caseSensitive = false) } - test("Supports multi-part table names for broadcast hint resolution") { - // local temp table (single-part identifier case) - checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("table", "table2"), - table("TaBlE").join(table("TaBlE2"))), - Join( - ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), - ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))), - Inner, - None, - JoinHint.NONE), - caseSensitive = false) - - checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"), - table("TaBlE").join(table("TaBlE2"))), - Join( - ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), - testRelation2, - Inner, - None, - JoinHint.NONE), - caseSensitive = true) + test("Supports multi-part table names for join strategy hint resolution") { + Seq(("MAPJOIN", BROADCAST), + ("MERGEJOIN", SHUFFLE_MERGE), + ("SHUFFLE_HASH", SHUFFLE_HASH), + ("SHUFFLE_REPLICATE_NL", SHUFFLE_REPLICATE_NL)).foreach { case (hintName, st) => + // local temp table (single-part identifier case) + checkAnalysis( + UnresolvedHint(hintName, Seq("table", "table2"), + table("TaBlE").join(table("TaBlE2"))), + Join( + ResolvedHint(testRelation, HintInfo(strategy = Some(st))), + ResolvedHint(testRelation2, HintInfo(strategy = Some(st))), + Inner, + None, + JoinHint.NONE), + caseSensitive = false) - // global temp table (multi-part identifier case) - checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("GlOBal_TeMP.table4", "table5"), - table("global_temp", "table4").join(table("global_temp", "table5"))), - Join( - ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))), - ResolvedHint(testRelation5, HintInfo(strategy = Some(BROADCAST))), - Inner, - None, - JoinHint.NONE), - caseSensitive = false) + checkAnalysis( + UnresolvedHint(hintName, Seq("TaBlE", "table2"), + table("TaBlE").join(table("TaBlE2"))), + Join( + ResolvedHint(testRelation, HintInfo(strategy = Some(st))), + testRelation2, + Inner, + None, + JoinHint.NONE), + caseSensitive = true) + + // global temp table (multi-part identifier case) + checkAnalysis( + UnresolvedHint(hintName, Seq("GlOBal_TeMP.table4", "table5"), + table("global_temp", "table4").join(table("global_temp", "table5"))), + Join( + ResolvedHint(testRelation4, HintInfo(strategy = Some(st))), + ResolvedHint(testRelation5, HintInfo(strategy = Some(st))), + Inner, + None, + JoinHint.NONE), + caseSensitive = false) - checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"), - table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), - Join( - ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))), - testRelation5, - Inner, - None, - JoinHint.NONE), - caseSensitive = true) + checkAnalysis( + UnresolvedHint(hintName, Seq("global_temp.TaBlE4", "table5"), + table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), + Join( + ResolvedHint(testRelation4, HintInfo(strategy = Some(st))), + testRelation5, + Inner, + None, + JoinHint.NONE), + caseSensitive = true) + } } }