 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
 
-case class SparkEquiInnerJoin(
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -40,33 +41,93 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
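+  // Decide which child is materialized into an in-memory hash table (the build
+  // side) and which child is streamed past it row by row (the streamed side).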
+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
   def output = left.output ++ right.output
 
-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
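+  // Projections that extract the join key from a row on each side; @transient
+  // lazy so each executor constructs its own copy on first use.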
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {
 
-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
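+    // requiredChildDistribution clusters both children on their join keys, so
+    // matching rows land in corresponding partitions and each partition pair
+    // can be joined locally.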
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      var currentRow: Row = null
+
+      // Create a mapping of buildKeys -> rows
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if (!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }
+
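+      // An iterator that walks the streamed side lazily, emitting one joined
+      // row per (streamed row, build-side match) pair.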
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1
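+        // A match position of -1 means no match list is currently being consumed.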
 
-  /**
-   * Filters any rows where any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
+        // Mutable per-row objects.
+        private[this] val joinRow = new JoinedRow
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
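+        // Either we are still walking the current match list, or the streamed
+        // side has another row whose key hits the hash table.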
+        override final def hasNext: Boolean =
+          (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+          (streamIter.hasNext && fetchNext())
+
+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
+          ret
+        }
+
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in
+         * the hash table.
+         *
+         * @return true if the search is successful, and false if the streamed iterator runs
+         *         out of tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1
+
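+          // Streamed rows with a null join key are skipped: under three-valued
+          // logic a null key can never equal anything, so they have no matches.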
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
+            }
+          }
+
+          if (currentHashMatches == null) {
+            false
+          } else {
+            currentMatchPosition = 0
+            true
+          }
+        }
+      }
     }
+  }
 }
 
 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
-    condition
-      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-      .getOrElse(Literal(true))
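+  // InterpretedPredicate turns the bound expression into a Row => Boolean, so the
+  // result no longer needs an asInstanceOf[Boolean] cast at each call site below.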
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
 
 
   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      // TODO: Use Spark's BitSet.
+      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
 
       streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
            matchedRows += buildRow(streamedRow ++ broadcastedRow)
            matched = true
            includedBroadcastTuples += i