Skip to content

Commit 91e9389

Browse files
JoshRosen authored and rxin committed
[SPARK-9729] [SPARK-9363] [SQL] Use sort merge join for left and right outer join
This patch adds a new `SortMergeOuterJoin` operator that performs left and right outer joins using sort merge join. It also refactors `SortMergeJoin` in order to improve performance and code clarity. Along the way, I also performed a couple pieces of minor cleanup and optimization: - Rename the `HashJoin` physical planner rule to `EquiJoinSelection`, since it's also used for non-hash joins. - Rewrite the comment at the top of `HashJoin` to better explain the precedence for choosing join operators. - Update `JoinSuite` to use `SqlTestUtils.withConf` for changing SQLConf settings. This patch incorporates several ideas from adrian-wang's patch, #5717. Closes #5717. <!-- Reviewable:start --> [<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/7904) <!-- Reviewable:end --> Author: Josh Rosen <[email protected]> Author: Daoyuan Wang <[email protected]> Closes #7904 from JoshRosen/outer-join-smj and squashes 1 commits.
1 parent 071bbad commit 91e9389

File tree

13 files changed

+1165
-319
lines changed

13 files changed

+1165
-319
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,20 @@ class JoinedRow extends InternalRow {
3737
}
3838

3939
/** Updates this JoinedRow to point at two new base rows. Returns itself. */
40-
def apply(r1: InternalRow, r2: InternalRow): InternalRow = {
40+
def apply(r1: InternalRow, r2: InternalRow): JoinedRow = {
4141
row1 = r1
4242
row2 = r2
4343
this
4444
}
4545

4646
/** Updates this JoinedRow by updating its left base row. Returns itself. */
47-
def withLeft(newLeft: InternalRow): InternalRow = {
47+
def withLeft(newLeft: InternalRow): JoinedRow = {
4848
row1 = newLeft
4949
this
5050
}
5151

5252
/** Updates this JoinedRow by updating its right base row. Returns itself. */
53-
def withRight(newRight: InternalRow): InternalRow = {
53+
def withRight(newRight: InternalRow): JoinedRow = {
5454
row2 = newRight
5555
this
5656
}

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
873873
HashAggregation ::
874874
Aggregation ::
875875
LeftSemiJoin ::
876-
HashJoin ::
876+
EquiJoinSelection ::
877877
InMemoryScans ::
878878
BasicOperators ::
879879
CartesianProduct ::
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import java.util.NoSuchElementException
21+
22+
import org.apache.spark.sql.catalyst.InternalRow
23+
24+
/**
25+
* An internal iterator interface which presents a more restrictive API than
26+
* [[scala.collection.Iterator]].
27+
*
28+
* One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()`
29+
* calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the
30+
* iterator to consume the next row, whereas RowIterator combines these calls into a single
31+
* [[advanceNext()]] method.
32+
*/
33+
private[sql] abstract class RowIterator {
34+
/**
35+
* Advance this iterator by a single row. Returns `false` if this iterator has no more rows
36+
* and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling
37+
* [[getRow]].
38+
*/
39+
def advanceNext(): Boolean
40+
41+
/**
42+
* Retrieve the row from this iterator. This method is idempotent. It is illegal to call this
43+
* method after [[advanceNext()]] has returned `false`.
44+
*/
45+
def getRow: InternalRow
46+
47+
/**
48+
* Convert this RowIterator into a [[scala.collection.Iterator]].
49+
*/
50+
def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
51+
}
52+
53+
object RowIterator {
54+
def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = {
55+
scalaIter match {
56+
case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter
57+
case _ => new RowIteratorFromScala(scalaIter)
58+
}
59+
}
60+
}
61+
62+
private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] {
63+
private [this] var hasNextWasCalled: Boolean = false
64+
private [this] var _hasNext: Boolean = false
65+
override def hasNext: Boolean = {
66+
// Idempotency:
67+
if (!hasNextWasCalled) {
68+
_hasNext = rowIter.advanceNext()
69+
hasNextWasCalled = true
70+
}
71+
_hasNext
72+
}
73+
override def next(): InternalRow = {
74+
if (!hasNext) throw new NoSuchElementException
75+
hasNextWasCalled = false
76+
rowIter.getRow
77+
}
78+
}
79+
80+
private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator {
81+
private[this] var _next: InternalRow = null
82+
override def advanceNext(): Boolean = {
83+
if (scalaIter.hasNext) {
84+
_next = scalaIter.next()
85+
true
86+
} else {
87+
_next = null
88+
false
89+
}
90+
}
91+
override def getRow: InternalRow = _next
92+
override def toScala: Iterator[InternalRow] = scalaIter
93+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6363
}
6464

6565
/**
66-
* Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
67-
* evaluated by matching hash keys.
66+
* Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates
67+
* can be evaluated by matching join keys.
6868
*
69-
* This strategy applies a simple optimization based on the estimates of the physical sizes of
70-
* the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an
71-
* estimated physical size smaller than the user-settable threshold
72-
* [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
73-
* ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
74-
* ''broadcasted'' to all of the executors involved in the join, as a
75-
* [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
76-
* will instead be used to decide the build side in a [[joins.ShuffledHashJoin]].
69+
* Join implementations are chosen with the following precedence:
70+
*
71+
* - Broadcast: if one side of the join has an estimated physical size that is smaller than the
72+
* user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
73+
* or if that side has an explicit broadcast hint (e.g. the user applied the
74+
* [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
75+
* of the join will be broadcasted and the other side will be streamed, with no shuffling
76+
* performed. If both sides of an inner join are eligible to be broadcasted then the right side is broadcasted, because that case is matched first.
77+
* - Sort merge: if the matching join keys are sortable and
78+
* [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join
79+
* will be used.
80+
* - Hash: will be chosen if neither of the above optimizations applies to this join.
7781
*/
78-
object HashJoin extends Strategy with PredicateHelper {
82+
object EquiJoinSelection extends Strategy with PredicateHelper {
7983

8084
private[this] def makeBroadcastHashJoin(
8185
leftKeys: Seq[Expression],
@@ -90,14 +94,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
9094
}
9195

9296
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
97+
98+
// --- Inner joins --------------------------------------------------------------------------
99+
93100
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
94101
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
95102

96103
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
97104
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
98105

99-
// If the sort merge join option is set, we want to use sort merge join prior to hashjoin
100-
// for now let's support inner join first, then add outer join
101106
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
102107
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
103108
val mergeJoin =
@@ -115,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
115120
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
116121
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
117122

123+
// --- Outer joins --------------------------------------------------------------------------
124+
118125
case ExtractEquiJoinKeys(
119126
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
120127
joins.BroadcastHashOuterJoin(
@@ -125,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
125132
joins.BroadcastHashOuterJoin(
126133
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
127134

135+
case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
136+
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
137+
joins.SortMergeOuterJoin(
138+
leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
139+
140+
case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
141+
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
142+
joins.SortMergeOuterJoin(
143+
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
144+
128145
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
129146
joins.ShuffledHashOuterJoin(
130147
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
131148

149+
// --- Cases where this strategy does not apply ---------------------------------------------
150+
132151
case _ => Nil
133152
}
134153
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ case class BroadcastNestedLoopJoin(
6565
left.output.map(_.withNullability(true)) ++ right.output
6666
case FullOuter =>
6767
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
68-
case _ =>
69-
left.output ++ right.output
68+
case x =>
69+
throw new IllegalArgumentException(
70+
s"BroadcastNestedLoopJoin should not take $x as the JoinType")
7071
}
7172
}
7273

0 commit comments

Comments
 (0)