
Commit 9d824fe

concretevitamin authored and marmbrus committed
[SQL] SPARK-1800 Add broadcast hash join operator & associated hints.
This PR is based off Michael's [PR 734](#734) and includes a bunch of cleanups.

Moreover, this PR also
- makes `SparkLogicalPlan` take a `tableName: String`, which facilitates testing.
- moves join-related tests to a single file.

Author: Zongheng Yang <[email protected]>
Author: Michael Armbrust <[email protected]>

Closes #1163 from concretevitamin/auto-broadcast-hash-join and squashes the following commits:

d0f4991 [Zongheng Yang] Fix bug in broadcast hash join & add test to cover it.
af080d7 [Zongheng Yang] Fix in joinIterators()'s next().
440d277 [Zongheng Yang] Fixes to imports; add back requiredChildDistribution (lost when merging)
208d5f6 [Zongheng Yang] Make LeftSemiJoinHash mix in HashJoin.
ad6c7cc [Zongheng Yang] Minor cleanups.
814b3bf [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
a8a093e [Zongheng Yang] Minor cleanups.
6fd8443 [Zongheng Yang] Cut down size estimation related stuff.
a4267be [Zongheng Yang] Add test for broadcast hash join and related necessary refactorings:
0e64b08 [Zongheng Yang] Scalastyle fix.
91461c2 [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
7c7158b [Zongheng Yang] Prototype of auto conversion to broadcast hash join.
0ad122f [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
3e5d77c [Zongheng Yang] WIP: giant and messy WIP.
a92ed0c [Michael Armbrust] Formatting.
76ca434 [Michael Armbrust] A simple strategy that broadcasts tables only when they are found in a configuration hint.
cf6b381 [Michael Armbrust] Split out generic logic for hash joins and create two concrete physical operators: BroadcastHashJoin and ShuffledHashJoin.
a8420ca [Michael Armbrust] Copy records in executeCollect to avoid issues with mutable rows.
1 parent 1132e47 commit 9d824fe
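The hint added here is driven entirely by configuration. A minimal end-to-end sketch of the intended usage, assuming SQLConf's set(key, value) mutator is mixed into SQLContext, with hypothetical tables named fact and dim:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

val sc = new SparkContext("local", "broadcast-join-sketch") // hypothetical app
val sqlContext = new SQLContext(sc)

// Mark the small table for broadcasting; the value is the comma-separated
// list read back by SQLConf.joinBroadcastTables.
sqlContext.set("spark.sql.join.broadcastTables", "dim")

// Assuming fact and dim were registered via registerRDDAsTable, equi-joins
// against "dim" can now be planned as a BroadcastHashJoin; other hash joins
// continue to go through ShuffledHashJoin.
val joined = sqlContext.sql("SELECT f.k, d.v FROM fact f JOIN dim d ON f.k = d.k")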

File tree

15 files changed (+395 −233 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 5 additions & 3 deletions
@@ -45,8 +45,10 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
  * that schema.
  *
  * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
- * each time an input row is added. This significatly reduces the cost of calcuating the
- * projection, but means that it is not safe
+ * each time an input row is added. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
+ * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
+ * and hold on to the returned [[Row]] before calling `next()`.
  */
 case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
@@ -67,7 +69,7 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row)
 }

 /**
- * A mutable wrapper that makes two rows appear appear as a single concatenated row. Designed to
+ * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
  * be instantiated once per thread and reused.
  */
 class JoinedRow extends Row {
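The contract spelled out in the new Scaladoc is the same issue fixed in executeCollect by commit a8420ca: buffered results must be copied. A short sketch of safe materialization, assuming catalyst's Row.copy(); the helper itself is hypothetical:

import org.apache.spark.sql.catalyst.expressions.{MutableProjection, Row}

// The projection reuses one underlying row object, so each result must be
// copied before the next input row overwrites it; toArray forces the
// iterator so every copy happens in order.
def materialize(input: Iterator[Row], project: MutableProjection): Array[Row] =
  input.map(row => project(row).copy()).toArray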

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala

Lines changed: 0 additions & 1 deletion
@@ -21,5 +21,4 @@ abstract class BaseRelation extends LeafNode {
   self: Product =>

   def tableName: String
-  def isPartitioned: Boolean = false
 }

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 17 additions & 0 deletions
@@ -29,9 +29,26 @@ import scala.collection.JavaConverters._
  */
 trait SQLConf {

+  /** ************************ Spark SQL Params/Hints ******************* */
+  // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext?
+
   /** Number of partitions to use for shuffle operators. */
   private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt

+  /**
+   * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
+   * a broadcast value during the physical executions of join operations. Setting this to 0
+   * effectively disables auto conversion.
+   * Hive setting: hive.auto.convert.join.noconditionaltask.size.
+   */
+  private[spark] def autoConvertJoinSize: Int =
+    get("spark.sql.auto.convert.join.size", "10000").toInt
+
+  /** A comma-separated list of table names marked to be broadcasted during joins. */
+  private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "")
+
+  /** ********************** SQLConf functionality methods ************ */
+
   @transient
   private val settings = java.util.Collections.synchronizedMap(
     new java.util.HashMap[String, String]())
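A short sketch of the two knobs together, assuming a set(key, value) mutator over the synchronized settings map above; the table names are hypothetical, and since the accessors are private[spark] the reads are shown as comments:

// Explicitly mark tables to broadcast (comma-separated, matched by name).
sqlContext.set("spark.sql.join.broadcastTables", "users,countries")
// joinBroadcastTables == "users,countries"

// Per the Scaladoc above, a zero size bound disables the automatic conversion.
sqlContext.set("spark.sql.auto.convert.join.size", "0")
// autoConvertJoinSize == 0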

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 5 additions & 1 deletion
@@ -170,7 +170,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group userf
    */
   def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
-    catalog.registerTable(None, tableName, rdd.logicalPlan)
+    val name = tableName
+    val newPlan = rdd.logicalPlan transform {
+      case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name)
+    }
+    catalog.registerTable(None, tableName, newPlan)
   }

   /**
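The transform above is what stamps a user-visible name onto plans built from existing RDDs. A sketch of the effect, with a hypothetical case class and dataset:

// Hypothetical schema and data; registerAsTable delegates to the
// registerRDDAsTable method shown above.
case class Person(id: Int, name: String)
val people = sc.parallelize(Seq(Person(1, "alice"), Person(2, "bob")))

import sqlContext.createSchemaRDD // implicit RDD -> SchemaRDD conversion
people.registerAsTable("people")

// The registered plan's SparkLogicalPlan node now carries
// tableName = "people", which is the name the broadcast hint matches.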

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 8 additions & 7 deletions
@@ -23,9 +23,9 @@ import org.apache.spark.sql.{Logging, Row}
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan

 /**
  * :: DeveloperApi ::
@@ -66,19 +66,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
  * linking.
  */
 @DeveloperApi
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
-  extends logical.LogicalPlan with MultiInstanceRelation {
+case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan")
+  extends BaseRelation with MultiInstanceRelation {

   def output = alreadyPlanned.output
-  def references = Set.empty
-  def children = Nil
+  override def references = Set.empty
+  override def children = Nil

   override final def newInstance: this.type = {
     SparkLogicalPlan(
       alreadyPlanned match {
         case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
         case _ => sys.error("Multiple instance of the same relation detected.")
-      }).asInstanceOf[this.type]
+      }, tableName)
+      .asInstanceOf[this.type]
   }
 }
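Forwarding tableName through newInstance matters because MultiInstanceRelation re-instantiates a relation when the same table appears more than once in a plan (e.g. a self-join). A sketch with hypothetical output and rdd values:

// Without the forwarding above, the copy would revert to the default name
// and silently stop matching spark.sql.join.broadcastTables.
val plan = SparkLogicalPlan(ExistingRdd(output, rdd), tableName = "users")
val reinstantiated = plan.newInstance
// reinstantiated.tableName == "users"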

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 46 additions & 8 deletions
@@ -21,10 +21,10 @@ import org.apache.spark.sql.{SQLContext, execution}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.parquet._
 import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
+import org.apache.spark.sql.parquet._

 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
@@ -45,14 +45,52 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }

+  /**
+   * Uses the HashFilteredJoin pattern to find joins where at least some of the predicates can be
+   * evaluated by matching hash keys.
+   */
   object HashJoin extends Strategy with PredicateHelper {
+    private[this] def broadcastHashJoin(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        left: LogicalPlan,
+        right: LogicalPlan,
+        condition: Option[Expression],
+        side: BuildSide) = {
+      val broadcastHashJoin = execution.BroadcastHashJoin(
+        leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
+      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
+    }
+
+    def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // Find inner joins where at least some predicates can be evaluated by matching hash keys
-      // using the HashFilteredJoin pattern.
+      case HashFilteredJoin(
+          Inner,
+          leftKeys,
+          rightKeys,
+          condition,
+          left,
+          right @ PhysicalOperation(_, _, b: BaseRelation))
+        if broadcastTables.contains(b.tableName) =>
+          broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
+
+      case HashFilteredJoin(
+          Inner,
+          leftKeys,
+          rightKeys,
+          condition,
+          left @ PhysicalOperation(_, _, b: BaseRelation),
+          right)
+        if broadcastTables.contains(b.tableName) =>
+          broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
+
       case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, right) =>
         val hashJoin =
-          execution.HashJoin(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
+          execution.ShuffledHashJoin(
+            leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+
       case _ => Nil
     }
   }
@@ -62,10 +100,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
         // Collect all aggregate expressions.
         val allAggregates =
-          aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+          aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
         // Collect all aggregate expressions that can be computed partially.
         val partialAggregates =
-          aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+          aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })

         // Only do partial aggregation if supported by all aggregate expressions.
         if (allAggregates.size == partialAggregates.size) {
@@ -242,7 +280,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ExistingRdd(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
         execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
-      case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
+      case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil
       case _ => Nil
     }
   }
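The net effect of the three cases can be checked from the planner output. A hedged sketch, assuming hypothetical tables fact and dim are registered and that SchemaRDD exposes queryExecution for debugging:

sqlContext.set("spark.sql.join.broadcastTables", "dim")
val q = sqlContext.sql("SELECT * FROM fact JOIN dim ON fact.k = dim.k")

// With the hint set, the physical plan should contain a BroadcastHashJoin
// (BuildRight, since "dim" is the right side); clearing the hint sends the
// same query through the ShuffledHashJoin case instead.
println(q.queryExecution.executedPlan)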

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 0 additions & 1 deletion
@@ -205,4 +205,3 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
-