
Commit 5731af5

marmbrus authored and rxin committed
[SQL] Rewrite join implementation to allow streaming of one relation.
Before we were materializing everything in memory. This also uses the projection interface so it will be easier to plug in code gen (it's ported from that branch).

@rxin @liancheng

Author: Michael Armbrust <[email protected]>

Closes #250 from marmbrus/hashJoin and squashes the following commits:

1ad873e [Michael Armbrust] Change hasNext logic back to the correct version.
8e6f2a2 [Michael Armbrust] Review comments.
1e9fb63 [Michael Armbrust] style
bc0cb84 [Michael Armbrust] Rewrite join implementation to allow streaming of one relation.
1 parent 841721e commit 5731af5
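
The rewrite follows the classic streamed hash join pattern: one child (the build side) is fully materialized into a hash table keyed by its join keys, while the other child is only streamed through and probed against that table, so at most one relation per partition is held in memory. Below is a minimal, self-contained sketch of that pattern over plain Scala iterators; all names in it (StreamedHashJoinSketch, hashJoin, buildKey, streamKey) are illustrative and not part of this commit, whose actual operator is execution.HashJoin shown in the diff.

import scala.collection.mutable.ArrayBuffer

object StreamedHashJoinSketch {

  def hashJoin[K, A, B](build: Iterator[A], stream: Iterator[B])
                       (buildKey: A => K, streamKey: B => K): Iterator[(B, A)] = {
    // Materialize only the build side into a hash table keyed by its join key.
    val table = new java.util.HashMap[K, ArrayBuffer[A]]()
    build.foreach { row =>
      val key = buildKey(row)
      if (key != null) {                 // null keys never match (cf. Row.anyNull)
        var matches = table.get(key)
        if (matches == null) {
          matches = new ArrayBuffer[A]()
          table.put(key, matches)
        }
        matches += row
      }
    }
    // The streamed side is consumed lazily, one row at a time.
    stream.flatMap { row =>
      val matches = table.get(streamKey(row))
      if (matches == null) Iterator.empty else matches.iterator.map(m => (row, m))
    }
  }

  def main(args: Array[String]): Unit = {
    val users  = Iterator((1, "ada"), (2, "grace"))              // build side
    val orders = Iterator((1, "book"), (1, "pen"), (3, "lamp"))  // streamed side
    hashJoin(users, orders)(_._1, _._1).foreach(println)
    // ((1,book),(1,ada)), ((1,pen),(1,ada)); order 3 finds no match.
  }
}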

File tree

6 files changed: +116, -37 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 10 additions & 0 deletions
@@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable {
     s"[${this.mkString(",")}]"
 
   def copy(): Row
+
+  /** Returns true if there are any NULL values in this row. */
+  def anyNull: Boolean = {
+    var i = 0
+    while (i < length) {
+      if (isNullAt(i)) { return true }
+      i += 1
+    }
+    false
+  }
 }
 
 /**
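
For context, anyNull is what lets both join operators skip rows whose join key contains a NULL, since an equality comparison with NULL never matches. A hypothetical stand-alone version of the same loop over a plain IndexedSeq (not the Row trait itself):

object AnyNullSketch {
  // Hypothetical helper mirroring Row.anyNull for a plain sequence of values.
  def anyNull(values: IndexedSeq[Any]): Boolean = {
    var i = 0
    while (i < values.length) {
      if (values(i) == null) { return true }
      i += 1
    }
    false
  }

  def main(args: Array[String]): Unit = {
    println(anyNull(IndexedSeq(1, "a", null)))  // true: this join key would be skipped
    println(anyNull(IndexedSeq(1, "a")))        // false
  }
}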

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
 
+object InterpretedPredicate {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
 trait Predicate extends Expression {
   self: Product =>
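
InterpretedPredicate simply wraps a boolean Expression in a Row => Boolean function, so call sites such as BroadcastNestedLoopJoin below no longer have to cast every evaluation result. A rough usage sketch, assuming Spark SQL is on the classpath; someBoundCondition and rows are hypothetical placeholders, not identifiers introduced by this commit:

// Sketch only: someBoundCondition is a hypothetical, already-bound boolean Expression,
// and rows is a hypothetical Iterator[Row].
val predicate: Row => Boolean = InterpretedPredicate(someBoundCondition)

rows.filter(predicate)   // keep only rows that satisfy the condition
// previously each call site did: someBoundCondition.apply(row).asInstanceOf[Boolean]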

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TopK ::
       PartialAggregation ::
-      SparkEquiInnerJoin ::
+      HashJoin ::
       ParquetOperations ::
       BasicOperators ::
       CartesianProduct ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
 
-  object SparkEquiInnerJoin extends Strategy {
+  object HashJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
         logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val leftKeys = joinKeys.map(_._1)
         val rightKeys = joinKeys.map(_._2)
 
-        val joinOp = execution.SparkEquiInnerJoin(
-          leftKeys, rightKeys, planLater(left), planLater(right))
+        val joinOp = execution.HashJoin(
+          leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
 
         // Make sure other conditions are met if present.
         if (otherPredicates.nonEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 95 additions & 32 deletions
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
 
-case class SparkEquiInnerJoin(
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -40,33 +41,93 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
   def output = left.output ++ right.output
 
-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {
 
-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      var currentRow: Row = null
+
+      // Create a mapping of buildKeys -> rows
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if(!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }
+
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1
 
-  /**
-   * Filters any rows where the any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
+        override final def hasNext: Boolean =
+          (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+            (streamIter.hasNext && fetchNext())
+
+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
+          ret
+        }
+
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in hashtable.
+         *
+         * @return true if the search is successful, and false the streamed iterator runs out of
+         *         tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1
+
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
+            }
+          }
+
+          if (currentHashMatches == null) {
+            false
+          } else {
+            currentMatchPosition = 0
+            true
+          }
+        }
+      }
     }
+  }
 }
 
 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
-    condition
-      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-      .getOrElse(Literal(true))
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
 
 
   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      // TODO: Use Spark's BitSet.
+      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
 
       streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
            matchedRows += buildRow(streamedRow ++ broadcastedRow)
            matched = true
            includedBroadcastTuples += i
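
The buildSide parameter decides which child is materialized into the hash table; the HashJoin strategy above always passes BuildRight, so the right child is built and the left child is streamed. A tiny sketch of what that selection amounts to (BuildSideSketch and pick are illustrative names, not part of the commit):

sealed abstract class BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

object BuildSideSketch {
  // Mirrors the (buildPlan, streamedPlan) / (buildKeys, streamedKeys) selection in HashJoin.
  def pick[A](buildSide: BuildSide, left: A, right: A): (A, A) = buildSide match {
    case BuildLeft  => (left, right)   // materialize the left child, stream the right
    case BuildRight => (right, left)   // materialize the right child, stream the left (the default above)
  }
}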

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
-      SparkEquiInnerJoin,
+      HashJoin,
       BasicOperators,
       CartesianProduct,
       BroadcastNestedLoopJoin

0 commit comments