From a8d367941a3fea9f475d6834b886d1c25883dc34 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 15 Oct 2015 17:18:43 -0700 Subject: [PATCH 1/4] support broadcast based Cartesian/Cross Join --- .../spark/sql/execution/SparkStrategies.scala | 21 +++++ .../joins/BroadcastNestedLoopJoin.scala | 3 +- .../org/apache/spark/sql/JoinSuite.scala | 92 +++++++++++++++++++ ... JOIN #1-0-abfc0b99ee357f71639f6162345fe8e | 20 ++++ ...JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd | 20 ++++ ...JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 | 20 ++++ ...JOIN #4-0-45f8602d257655322b7d18cad09f6a0f | 20 ++++ .../sql/hive/execution/HivePlanTest.scala | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 54 +++++++++++ 9 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 create mode 100644 sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79bd1a41808de..82ec1ffd5c1a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,6 +312,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // Not like the equal-join, BroadcastNestedLoopJoin doesn't support condition + // for cartesian join, as in cartesian join, probably, the records satisfy the + // condition, but exists in another partition of the large table, so we may not able + // to eliminate the duplicates. + case logical.Join( + CanBroadcast(left), right, joinType @ (FullOuter | LeftOuter | RightOuter), None) => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, None) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType @ (FullOuter | LeftOuter | RightOuter), None) => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, None) :: Nil + // Since BroadCastNestedLoopJoin supports condition already, we simply passed it down. + case logical.Join( + CanBroadcast(left), right, Inner, condition) => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, Inner, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), Inner, condition) => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, Inner, condition) :: Nil case logical.Join(left, right, _, None) => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index efef8c8a8b96a..176cdc9e1dc0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.CompactBuffer @@ -67,6 +67,7 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case Inner => left.output ++ right.output case x => throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index b1fb06815868c..a9ca46cab067d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() + def statisticSizeInByte(df: DataFrame): BigInt = { + df.queryExecution.optimizedPlan.statistics.sizeInBytes + } + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") @@ -466,6 +470,94 @@ class JoinSuite extends QueryTest with SharedSQLContext { sql("UNCACHE TABLE testData") } + test("cross join with broadcast") { + sql("CACHE TABLE testData") + + val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + + // we set the threshold is greater than statistic of the cached table testData + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { + + assert(statisticSizeInByte(sqlContext.table("testData2")) > + sqlContext.conf.autoBroadcastJoinThreshold) + + assert(statisticSizeInByte(sqlContext.table("testData")) < + sqlContext.conf.autoBroadcastJoinThreshold) + + Seq( + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", + classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2", + classOf[LeftSemiJoinBNL]), + ("SELECT * FROM testData JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", + classOf[BroadcastNestedLoopJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2 + """.stripMargin), + Row("2", 1, 1) :: + Row("2", 1, 2) :: + Row("2", 2, 1) :: + Row("2", 2, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + + checkAnswer( + sql( + """ + SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a + """.stripMargin), + Row("1", 2, 1) :: + Row("1", 2, 2) :: + Row("1", 3, 1) :: + Row("1", 3, 2) :: + Row("2", 3, 1) :: + Row("2", 3, 2) :: Nil) + } + + sql("UNCACHE TABLE testData") + } + test("left semi join") { val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e new file mode 100644 index 0000000000000..0bb9399af0c45 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1-0-abfc0b99ee357f71639f6162345fe8e @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +306 0 +306 0 +306 0 +307 0 +307 0 +307 0 +307 0 +307 0 +307 0 +308 0 +308 0 +308 0 +309 0 +309 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd new file mode 100644 index 0000000000000..4e455ed255117 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2-0-8412a39ee57885ccb0aaf848db8ef1dd @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 new file mode 100644 index 0000000000000..4e455ed255117 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3-0-e8a0427dbde35eea6011144443e5ffb4 @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f new file mode 100644 index 0000000000000..4e455ed255117 --- /dev/null +++ b/sql/hive/src/test/resources/golden/SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4-0-45f8602d257655322b7d18cad09f6a0f @@ -0,0 +1,20 @@ +302 0 +302 0 +302 0 +305 0 +305 0 +305 0 +305 2 +305 4 +306 0 +306 0 +306 0 +306 2 +306 4 +306 5 +306 5 +306 5 +307 0 +307 0 +307 0 +307 0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index cd055f9eca37e..6a2b3cd46d177 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2878500453141..2eb2a62f20756 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin + import scala.util.Try import org.scalatest.BeforeAndAfter @@ -69,6 +71,58 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + // Testing the Broadcast based join for cartesian join (cross join) + // We assume that the Broadcast Join Threshold will works since the src is a small table + private val spark_10484_1 = """ + | SELECT a.key, b.key + | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY b.key, a.key + | LIMIT 20 + """.stripMargin + private val spark_10484_2 = """ + | SELECT a.key, b.key + | FROM src a RIGHT JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_3 = """ + | SELECT a.key, b.key + | FROM src a FULL OUTER JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + private val spark_10484_4 = """ + | SELECT a.key, b.key + | FROM src a JOIN src b WHERE a.key > b.key + 300 + | ORDER BY a.key, b.key + | LIMIT 20 + """.stripMargin + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #1", + spark_10484_1) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #2", + spark_10484_2) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #3", + spark_10484_3) + + createQueryTest("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN #4", + spark_10484_4) + + test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") { + def assertBroadcastNestedLoopJoin(sqlText: String): Boolean = { + sql(sqlText).queryExecution.sparkPlan.collect { + case _: BroadcastNestedLoopJoin => 1 + }.size > 0 + } + + assertBroadcastNestedLoopJoin(spark_10484_1) + assertBroadcastNestedLoopJoin(spark_10484_2) + assertBroadcastNestedLoopJoin(spark_10484_3) + assertBroadcastNestedLoopJoin(spark_10484_4) + } + createQueryTest("SPARK-8976 Wrong Result for Rollup #1", """ SELECT count(*) AS cnt, key % 5,GROUPING__ID FROM src group by key%5 WITH ROLLUP From 024a1fb17ab5181840400206a044eef808a8a293 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 18 Oct 2015 17:50:31 -0700 Subject: [PATCH 2/4] update the style issue and add extra assert --- .../apache/spark/sql/execution/SparkStrategies.scala | 10 +++++----- .../apache/spark/sql/hive/execution/HivePlanTest.scala | 1 - .../spark/sql/hive/execution/HiveQuerySuite.scala | 8 ++++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82ec1ffd5c1a4..f49325cf6a751 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,10 +312,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Not like the equal-join, BroadcastNestedLoopJoin doesn't support condition - // for cartesian join, as in cartesian join, probably, the records satisfy the - // condition, but exists in another partition of the large table, so we may not able - // to eliminate the duplicates. + // Not like the equal-join, BroadcastNestedLoopJoin doesn't support condition + // for cartesian join, as in cartesian join, probably, the records satisfy the + // condition, but exists in another partition of the large table, so we may not able + // to eliminate the duplicates. case logical.Join( CanBroadcast(left), right, joinType @ (FullOuter | LeftOuter | RightOuter), None) => execution.joins.BroadcastNestedLoopJoin( @@ -324,7 +324,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left, CanBroadcast(right), joinType @ (FullOuter | LeftOuter | RightOuter), None) => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildRight, joinType, None) :: Nil - // Since BroadCastNestedLoopJoin supports condition already, we simply passed it down. + // Since BroadCastNestedLoopJoin supports condition already, we simply passed it down. case logical.Join( CanBroadcast(left), right, Inner, condition) => execution.joins.BroadcastNestedLoopJoin( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index 6a2b3cd46d177..cd055f9eca37e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2eb2a62f20756..b52f7d4b57899 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -73,7 +73,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Testing the Broadcast based join for cartesian join (cross join) // We assume that the Broadcast Join Threshold will works since the src is a small table - private val spark_10484_1 = """ + private val spark_10484_1 = """ | SELECT a.key, b.key | FROM src a LEFT JOIN src b WHERE a.key > b.key + 300 | ORDER BY b.key, a.key @@ -111,10 +111,10 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { spark_10484_4) test("SPARK-10484 Optimize the Cartesian (Cross) Join with broadcast based JOIN") { - def assertBroadcastNestedLoopJoin(sqlText: String): Boolean = { - sql(sqlText).queryExecution.sparkPlan.collect { + def assertBroadcastNestedLoopJoin(sqlText: String): Unit = { + assert(sql(sqlText).queryExecution.sparkPlan.collect { case _: BroadcastNestedLoopJoin => 1 - }.size > 0 + }.nonEmpty) } assertBroadcastNestedLoopJoin(spark_10484_1) From 975eb461d06d5c49a0d6f5a3dde5d682ff05ebf5 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 25 Oct 2015 22:49:38 -0700 Subject: [PATCH 3/4] Add DefaultJoin Strategy --- .../spark/sql/execution/SparkPlanner.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 59 ++++++++----------- .../joins/BroadcastNestedLoopJoin.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 3 +- 4 files changed, 31 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b346f43faebe2..0f98fe88b2101 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { EquiJoinSelection :: InMemoryScans :: BasicOperators :: + BroadcastNestedLoop :: CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) + DefaultJoin :: Nil) /** * Used to build table scan operators where complex projection and filtering are done using diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f49325cf6a751..60fd60b9fca69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -294,46 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - - object BroadcastNestedLoopJoin extends Strategy { + object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case logical.Join( + CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil + case logical.Join( + left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin => + execution.joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil } } object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Not like the equal-join, BroadcastNestedLoopJoin doesn't support condition - // for cartesian join, as in cartesian join, probably, the records satisfy the - // condition, but exists in another partition of the large table, so we may not able - // to eliminate the duplicates. - case logical.Join( - CanBroadcast(left), right, joinType @ (FullOuter | LeftOuter | RightOuter), None) => - execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildLeft, joinType, None) :: Nil - case logical.Join( - left, CanBroadcast(right), joinType @ (FullOuter | LeftOuter | RightOuter), None) => - execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildRight, joinType, None) :: Nil - // Since BroadCastNestedLoopJoin supports condition already, we simply passed it down. - case logical.Join( - CanBroadcast(left), right, Inner, condition) => - execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildLeft, Inner, condition) :: Nil - case logical.Join( - left, CanBroadcast(right), Inner, condition) => - execution.joins.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joins.BuildRight, Inner, condition) :: Nil - case logical.Join(left, right, _, None) => + // TODO CartesianProduct doesn't support the Left Semi Join + case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, @@ -342,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object DefaultJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object TakeOrderedAndProject extends Strategy { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 176cdc9e1dc0d..0924f336f3ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -68,7 +68,7 @@ case class BroadcastNestedLoopJoin( case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case Inner => left.output ++ right.output - case x => + case x => // TODO support the Left Semi Join throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4d8a3f728e6b5..e62551903e3cf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -587,8 +587,9 @@ class HiveContext private[hive]( LeftSemiJoin, EquiJoinSelection, BasicOperators, + BroadcastNestedLoop, CartesianProduct, - BroadcastNestedLoopJoin + DefaultJoin ) } From 7fda51170c1c994c608be9e362f5464990b3204f Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 27 Oct 2015 14:24:00 +0800 Subject: [PATCH 4/4] Add TODO for further improvement. --- .../spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 0924f336f3ce1..05d20f511aef8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -67,7 +67,9 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case Inner => left.output ++ right.output + case Inner => + // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case + left.output ++ right.output case x => // TODO support the Left Semi Join throw new IllegalArgumentException( s"BroadcastNestedLoopJoin should not take $x as the JoinType")