Commit 8e6f2a2

Review comments.
1 parent 1e9fb63 commit 8e6f2a2

File tree

3 files changed: +46 −39 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 2 additions & 2 deletions
@@ -48,8 +48,8 @@ trait Row extends Seq[Any] with Serializable {
   /** Returns true if there are any NULL values in this row. */
   def anyNull: Boolean = {
     var i = 0
-    while(i < length) {
-      if(isNullAt(i)) return true
+    while (i < length) {
+      if (isNullAt(i)) { return true }
       i += 1
     }
     false

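A note on the pattern: anyNull is the guard the hash join below relies on to drop rows whose join key contains a NULL. A minimal standalone sketch of the same scan, assuming a row backed by an IndexedSeq[Any] (SimpleRow is a hypothetical stand-in for catalyst's Row, not part of this commit):

  // Hypothetical stand-in for catalyst's Row, for illustration only.
  class SimpleRow(values: IndexedSeq[Any]) {
    def length: Int = values.length
    def isNullAt(i: Int): Boolean = values(i) == null

    // Same shape as the patched method: scan left to right, stop at the first NULL.
    def anyNull: Boolean = {
      var i = 0
      while (i < length) {
        if (isNullAt(i)) { return true }
        i += 1
      }
      false
    }
  }

  // new SimpleRow(Vector(1, null, 3)).anyNull  evaluates to true
  // new SimpleRow(Vector(1, 2, 3)).anyNull     evaluates to false
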
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
 
+object InterpretedPredicate {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
 trait Predicate extends Expression {
   self: Product =>

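The new factory simply wraps expression interpretation in an ordinary Scala function, so the result can be handed to anything expecting a Row => Boolean. A minimal sketch of the same pattern over a toy expression tree (Expr, Attr, Lit, GreaterThan, and ToyPredicate are hypothetical stand-ins for the catalyst types, with rows modeled as Map[String, Any]):

  // Toy expression tree; eval returns Any, mirroring Expression.apply.
  sealed trait Expr { def eval(row: Map[String, Any]): Any }
  case class Attr(name: String) extends Expr {
    def eval(row: Map[String, Any]): Any = row(name)
  }
  case class Lit(value: Any) extends Expr {
    def eval(row: Map[String, Any]): Any = value
  }
  case class GreaterThan(left: Expr, right: Expr) extends Expr {
    def eval(row: Map[String, Any]): Any =
      left.eval(row).asInstanceOf[Int] > right.eval(row).asInstanceOf[Int]
  }

  // Same shape as InterpretedPredicate: evaluate per row, cast the result to Boolean.
  object ToyPredicate {
    def apply(e: Expr): Map[String, Any] => Boolean =
      (row: Map[String, Any]) => e.eval(row).asInstanceOf[Boolean]
  }

  // The returned closure plugs straight into ordinary collection operations:
  //   Seq[Map[String, Any]](Map("age" -> 21), Map("age" -> 12))
  //     .filter(ToyPredicate(GreaterThan(Attr("age"), Lit(18))))  // keeps only age 21
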
sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 38 additions & 37 deletions
@@ -15,29 +15,20 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql
-package execution
+package org.apache.spark.sql.execution
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-import catalyst.errors._
-import catalyst.expressions._
-import catalyst.plans._
-import catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
 sealed abstract class BuildSide
 case object BuildLeft extends BuildSide
 case object BuildRight extends BuildSide
 
-object InterpretCondition {
-  def apply(expression: Expression): (Row => Boolean) = {
-    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
-  }
-}
-
 case class HashJoin(
   leftKeys: Seq[Expression],
   rightKeys: Seq[Expression],
@@ -69,11 +60,12 @@ case class HashJoin(
   def execute() = {
 
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
       val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
       var currentRow: Row = null
 
       // Create a mapping of buildKeys -> rows
-      while(buildIter.hasNext) {
+      while (buildIter.hasNext) {
         currentRow = buildIter.next()
         val rowKey = buildSideKeyGenerator(currentRow)
         if(!rowKey.anyNull) {
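
The build phase above drains one side of the join into a multimap from join key to all rows with that key, skipping keys containing NULLs (which can never match in an equi-join). The same idea in standalone form, assuming Seq[Any] rows and a caller-supplied key extractor (buildHashTable and keyOf are hypothetical names):

  import scala.collection.mutable.ArrayBuffer

  // Build side of a hash join: join key -> all build rows sharing that key.
  def buildHashTable(
      buildIter: Iterator[Seq[Any]],
      keyOf: Seq[Any] => Seq[Any]): java.util.HashMap[Seq[Any], ArrayBuffer[Seq[Any]]] = {
    val hashTable = new java.util.HashMap[Seq[Any], ArrayBuffer[Seq[Any]]]()
    while (buildIter.hasNext) {
      val currentRow = buildIter.next()
      val rowKey = keyOf(currentRow)
      if (!rowKey.contains(null)) {  // NULL keys never match; drop them here
        var matches = hashTable.get(rowKey)
        if (matches == null) {
          matches = new ArrayBuffer[Seq[Any]]()
          hashTable.put(rowKey, matches)
        }
        matches += currentRow
      }
    }
    hashTable
  }
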
@@ -90,40 +82,49 @@ case class HashJoin(
       }
 
       new Iterator[Row] {
-        private[this] var currentRow: Row = _
-        private[this] var currentMatches: ArrayBuffer[Row] = _
-        private[this] var currentPosition: Int = -1
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1
 
         // Mutable per row objects.
         private[this] val joinRow = new JoinedRow
 
-        @transient private val joinKeys = streamSideKeyGenerator()
+        private[this] val joinKeys = streamSideKeyGenerator()
 
-        def hasNext: Boolean =
-          (currentPosition != -1 && currentPosition < currentMatches.size) ||
-          (streamIter.hasNext && fetchNext())
+        override final def hasNext: Boolean =
+          if (currentMatchPosition != -1) {
+            currentMatchPosition < currentHashMatches.size
+          } else {
+            fetchNext()
+          }
 
-        def next() = {
-          val ret = joinRow(currentRow, currentMatches(currentPosition))
-          currentPosition += 1
+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
           ret
         }
 
-        private def fetchNext(): Boolean = {
-          currentMatches = null
-          currentPosition = -1
-
-          while (currentMatches == null && streamIter.hasNext) {
-            currentRow = streamIter.next()
-            if(!joinKeys(currentRow).anyNull) {
-              currentMatches = hashTable.get(joinKeys.currentValue)
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in the
+         * hash table.
+         *
+         * @return true if the search is successful, and false if the streamed iterator runs out
+         *         of tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1
+
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
             }
           }
 
-          if (currentMatches == null) {
+          if (currentHashMatches == null) {
             false
           } else {
-            currentPosition = 0
+            currentMatchPosition = 0
             true
           }
         }
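
The renamed iterator makes the probe phase lazy: hasNext only advances the streamed side once the current match buffer is exhausted, and fetchNext skips streamed rows that have no hash matches or a NULL key. A simplified standalone version of the same state machine (rows reduced to Seq[Any], pairs emitted in place of JoinedRow; probe and keyOf are hypothetical names). One adjustment for the simplified form: next() resets currentMatchPosition to -1 once a match buffer is exhausted, so the if branch in hasNext falls through to fetchNext():

  import scala.collection.mutable.ArrayBuffer

  def probe(
      streamIter: Iterator[Seq[Any]],
      keyOf: Seq[Any] => Seq[Any],
      hashTable: java.util.HashMap[Seq[Any], ArrayBuffer[Seq[Any]]])
    : Iterator[(Seq[Any], Seq[Any])] =
    new Iterator[(Seq[Any], Seq[Any])] {
      private[this] var currentStreamedRow: Seq[Any] = _
      private[this] var currentHashMatches: ArrayBuffer[Seq[Any]] = _
      private[this] var currentMatchPosition: Int = -1

      override final def hasNext: Boolean =
        if (currentMatchPosition != -1) {
          currentMatchPosition < currentHashMatches.size
        } else {
          fetchNext()
        }

      override final def next(): (Seq[Any], Seq[Any]) = {
        val ret = (currentStreamedRow, currentHashMatches(currentMatchPosition))
        currentMatchPosition += 1
        // Exhausted this buffer: force the next hasNext call into fetchNext().
        if (currentMatchPosition == currentHashMatches.size) currentMatchPosition = -1
        ret
      }

      // Advance the streamed side to the next row with at least one match.
      private final def fetchNext(): Boolean = {
        currentHashMatches = null
        currentMatchPosition = -1
        while (currentHashMatches == null && streamIter.hasNext) {
          currentStreamedRow = streamIter.next()
          if (!keyOf(currentStreamedRow).contains(null)) {
            currentHashMatches = hashTable.get(keyOf(currentStreamedRow))
          }
        }
        if (currentHashMatches == null) false
        else { currentMatchPosition = 0; true }
      }
    }
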
@@ -158,7 +159,7 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
-    InterpretCondition(
+    InterpretedPredicate(
       condition
         .map(c => BindReferences.bindReference(c, left.output ++ right.output))
         .getOrElse(Literal(true)))
@@ -169,8 +170,8 @@ case class BroadcastNestedLoopJoin(
 
   val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
     val matchedRows = new ArrayBuffer[Row]
-    val includedBroadcastTuples =
-      new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
+    // TODO: Use Spark's BitSet.
+    val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
     val joinedRow = new JoinedRow
 
     streamedIter.foreach { streamedRow =>
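
In the surrounding outer-join bookkeeping, the BitSet records which broadcast tuples matched at least one streamed row, so unmatched broadcast rows can be emitted afterwards. A standalone sketch of that pattern (Seq[Any] rows and a generic match predicate; matchAgainstBroadcast is a hypothetical name):

  import scala.collection.mutable.{ArrayBuffer, BitSet}

  def matchAgainstBroadcast(
      streamedIter: Iterator[Seq[Any]],
      broadcastRows: IndexedSeq[Seq[Any]],
      matches: (Seq[Any], Seq[Any]) => Boolean)
    : (ArrayBuffer[(Seq[Any], Seq[Any])], BitSet) = {
    val matchedRows = new ArrayBuffer[(Seq[Any], Seq[Any])]
    // One bit per broadcast tuple; bit i is set once row i has matched anything.
    val includedBroadcastTuples = new BitSet(broadcastRows.size)

    streamedIter.foreach { streamedRow =>
      var i = 0
      while (i < broadcastRows.size) {
        if (matches(streamedRow, broadcastRows(i))) {
          matchedRows += ((streamedRow, broadcastRows(i)))
          includedBroadcastTuples += i
        }
        i += 1
      }
    }
    (matchedRows, includedBroadcastTuples)
  }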
