@@ -142,29 +142,23 @@ case class HashJoin(
142142
143143/**
144144 * :: DeveloperApi ::
145+ * Build the right table's join keys into a HashSet, and iteratively go through the left
146+ * table to find if the join keys are in the hash set.
145147 */
146148@ DeveloperApi
147149case class LeftSemiJoinHash (
148- leftKeys : Seq [Expression ],
149- rightKeys : Seq [Expression ],
150- buildSide : BuildSide ,
151- left : SparkPlan ,
152- right : SparkPlan ) extends BinaryNode {
150+ leftKeys : Seq [Expression ],
151+ rightKeys : Seq [Expression ],
152+ left : SparkPlan ,
153+ right : SparkPlan ) extends BinaryNode {
153154
154155 override def outputPartitioning : Partitioning = left.outputPartitioning
155156
156157 override def requiredChildDistribution =
157158 ClusteredDistribution (leftKeys) :: ClusteredDistribution (rightKeys) :: Nil
158159
159- val (buildPlan, streamedPlan) = buildSide match {
160- case BuildLeft => (left, right)
161- case BuildRight => (right, left)
162- }
163-
164- val (buildKeys, streamedKeys) = buildSide match {
165- case BuildLeft => (leftKeys, rightKeys)
166- case BuildRight => (rightKeys, leftKeys)
167- }
160+ val (buildPlan, streamedPlan) = (right, left)
161+ val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
168162
169163 def output = left.output
170164
@@ -175,24 +169,18 @@ case class LeftSemiJoinHash(
175169 def execute () = {
176170
177171 buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
178- // TODO: Use Spark's HashMap implementation.
179- val hashTable = new java.util.HashMap [Row , ArrayBuffer [Row ]]()
172+ val hashTable = new java.util.HashSet [Row ]()
180173 var currentRow : Row = null
181174
182- // Create a mapping of buildKeys -> rows
175+ // Create a Hash set of buildKeys
183176 while (buildIter.hasNext) {
184177 currentRow = buildIter.next()
185178 val rowKey = buildSideKeyGenerator(currentRow)
186179 if (! rowKey.anyNull) {
187- val existingMatchList = hashTable.get(rowKey)
188- val matchList = if (existingMatchList == null ) {
189- val newMatchList = new ArrayBuffer [Row ]()
190- hashTable.put(rowKey, newMatchList)
191- newMatchList
192- } else {
193- existingMatchList
180+ val keyExists = hashTable.contains(rowKey)
181+ if (! keyExists) {
182+ hashTable.add(rowKey)
194183 }
195- matchList += currentRow.copy()
196184 }
197185 }
198186
@@ -220,7 +208,7 @@ case class LeftSemiJoinHash(
220208 while (! currentHashMatched && streamIter.hasNext) {
221209 currentStreamedRow = streamIter.next()
222210 if (! joinKeys(currentStreamedRow).anyNull) {
223- currentHashMatched = true
211+ currentHashMatched = hashTable.contains(joinKeys.currentValue)
224212 }
225213 }
226214 currentHashMatched
@@ -232,6 +220,8 @@ case class LeftSemiJoinHash(
232220
233221/**
234222 * :: DeveloperApi ::
223+ * Using BroadcastNestedLoopJoin to calculate the left semi join result when there are no
224+ * join keys for hash join.
235225 */
236226@ DeveloperApi
237227case class LeftSemiJoinBNL (
@@ -261,26 +251,23 @@ case class LeftSemiJoinBNL(
261251 def execute () = {
262252 val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
263253
264- val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
254+ streamed.execute().mapPartitions { streamedIter =>
265255 val joinedRow = new JoinedRow
266256
267257 streamedIter.filter(streamedRow => {
268258 var i = 0
269259 var matched = false
270260
271261 while (i < broadcastedRelation.value.size && ! matched) {
272- // TODO: One bitset per partition instead of per row.
273262 val broadcastedRow = broadcastedRelation.value(i)
274263 if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
275264 matched = true
276265 }
277266 i += 1
278267 }
279268 matched
280- }).map(streamedRow => (streamedRow, null ))
269+ })
281270 }
282-
283- streamedPlusMatches.map(_._1)
284271 }
285272}
286273
0 commit comments