1515 * limitations under the License.
1616 */
1717
18- package org .apache .spark .sql
19- package execution
18+ package org .apache .spark .sql .execution
2019
2120import scala .collection .mutable .{ArrayBuffer , BitSet }
2221
23- import org .apache .spark .rdd .RDD
2422import org .apache .spark .SparkContext
2523
26- import catalyst .errors ._
27- import catalyst .expressions ._
28- import catalyst .plans ._
29- import catalyst .plans .physical .{ClusteredDistribution , Partitioning }
24+ import org .apache .spark .sql .catalyst .expressions ._
25+ import org .apache .spark .sql .catalyst .plans ._
26+ import org .apache .spark .sql .catalyst .plans .physical .{ClusteredDistribution , Partitioning }
3027
/**
 * Closed set of markers naming one side of a two-input join.
 *
 * NOTE(review): presumably selects which child plan is loaded into the hash table by the
 * hash join in this file; confirm against the part of HashJoin not shown here.
 */
sealed abstract class BuildSide

/** Marker for the left side. */
case object BuildLeft extends BuildSide

/** Marker for the right side. */
case object BuildRight extends BuildSide
3431
35- object InterpretCondition {
36- def apply (expression : Expression ): (Row => Boolean ) = {
37- (r : Row ) => expression.apply(r).asInstanceOf [Boolean ]
38- }
39- }
40-
4132case class HashJoin (
4233 leftKeys : Seq [Expression ],
4334 rightKeys : Seq [Expression ],
@@ -69,11 +60,12 @@ case class HashJoin(
6960 def execute () = {
7061
7162 buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
63+ // TODO: Use Spark's HashMap implementation.
7264 val hashTable = new java.util.HashMap [Row , ArrayBuffer [Row ]]()
7365 var currentRow : Row = null
7466
7567 // Create a mapping of buildKeys -> rows
76- while (buildIter.hasNext) {
68+ while (buildIter.hasNext) {
7769 currentRow = buildIter.next()
7870 val rowKey = buildSideKeyGenerator(currentRow)
7971 if (! rowKey.anyNull) {
@@ -90,40 +82,49 @@ case class HashJoin(
9082 }
9183
9284 new Iterator [Row ] {
93- private [this ] var currentRow : Row = _
94- private [this ] var currentMatches : ArrayBuffer [Row ] = _
95- private [this ] var currentPosition : Int = - 1
85+ private [this ] var currentStreamedRow : Row = _
86+ private [this ] var currentHashMatches : ArrayBuffer [Row ] = _
87+ private [this ] var currentMatchPosition : Int = - 1
9688
9789 // Mutable per row objects.
9890 private [this ] val joinRow = new JoinedRow
9991
100- @ transient private val joinKeys = streamSideKeyGenerator()
92+ private [ this ] val joinKeys = streamSideKeyGenerator()
10193
/**
 * Reports whether another joined row can be produced.
 *
 * Returns true when the match buffer for the current streamed row still has unconsumed
 * entries. Otherwise it calls fetchNext() to advance the streamed iterator to the next row
 * that has matches in the hash table, and returns that call's result.
 */
override final def hasNext: Boolean =
  // A position equal to the buffer size only means the current streamed row is used up,
  // not that the join is finished, so that case must fall through to fetchNext() rather
  // than report false.
  (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
    (streamIter.hasNext && fetchNext())
105100
/**
 * Produces the join of the current streamed row with the match at the current position,
 * then moves the position forward by one.
 *
 * The result comes from the shared mutable `joinRow` object.
 */
override final def next() = {
  val joined = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
  // Advance only after the lookup succeeded.
  currentMatchPosition += 1
  joined
}
111106
/**
 * Advances the streamed iterator until it reaches a row whose join key contains no nulls
 * and for which the hash table returns a non-null match buffer.
 *
 * On success the match position is reset to 0; otherwise the match state is left cleared
 * (null buffer, position -1).
 *
 * @return true if such a row was found, and false if the streamed iterator ran out of
 *         tuples first.
 */
private final def fetchNext(): Boolean = {
  // Clear the state left over from the previous streamed row.
  currentHashMatches = null
  currentMatchPosition = -1

  while (currentHashMatches == null && streamIter.hasNext) {
    currentStreamedRow = streamIter.next()
    // Rows with a null in the join key are skipped without probing the table.
    if (!joinKeys(currentStreamedRow).anyNull) {
      // java.util.HashMap.get yields null for an absent key, which keeps the loop going.
      currentHashMatches = hashTable.get(joinKeys.currentValue)
    }
  }

  val foundMatches = currentHashMatches != null
  if (foundMatches) {
    currentMatchPosition = 0
  }
  foundMatches
}
@@ -158,7 +159,7 @@ case class BroadcastNestedLoopJoin(
158159 def right = broadcast
159160
160161 @ transient lazy val boundCondition =
161- InterpretCondition (
162+ InterpretedPredicate (
162163 condition
163164 .map(c => BindReferences .bindReference(c, left.output ++ right.output))
164165 .getOrElse(Literal (true )))
@@ -169,8 +170,8 @@ case class BroadcastNestedLoopJoin(
169170
170171 val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
171172 val matchedRows = new ArrayBuffer [Row ]
172- val includedBroadcastTuples =
173- new scala.collection.mutable. BitSet (broadcastedRelation.value.size)
173+ // TODO: Use Spark's BitSet.
174+ val includedBroadcastTuples = new BitSet (broadcastedRelation.value.size)
174175 val joinedRow = new JoinedRow
175176
176177 streamedIter.foreach { streamedRow =>
0 commit comments