
Commit a39be8c

[SPARK-3857] Create a join package for various join operators.
1 parent bc44187 commit a39be8c
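
The change is mechanical: the physical join operators move out of the top level of org.apache.spark.sql.execution into a dedicated join subpackage, and call sites in the planner are updated. Schematically (a sketch; the class names are taken from the hunks below):

  // before this commit
  //   org.apache.spark.sql.execution.BroadcastHashJoin
  //   org.apache.spark.sql.execution.CartesianProduct
  // after this commit
  //   org.apache.spark.sql.execution.join.BroadcastHashJoin
  //   org.apache.spark.sql.execution.join.CartesianProduct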

File tree

15 files changed: +843 -647 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 22 additions & 20 deletions
@@ -27,20 +27,19 @@ import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
 import org.apache.spark.sql.parquet._
 
+
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
 
   object LeftSemiJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
-        val semiJoin = execution.LeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right))
+        val semiJoin = join.LeftSemiJoinHash(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
-        execution.LeftSemiJoinBNL(
-          planLater(left), planLater(right), condition) :: Nil
+        join.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
       case _ => Nil
     }
   }

@@ -50,13 +49,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    * evaluated by matching hash keys.
    *
    * This strategy applies a simple optimization based on the estimates of the physical sizes of
-   * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an
+   * the two join sides. When planning a [[join.BroadcastHashJoin]], if one side has an
    * estimated physical size smaller than the user-settable threshold
    * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
    * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
    * ''broadcasted'' to all of the executors involved in the join, as a
    * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
-   * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]].
+   * will instead be used to decide the build side in a [[join.ShuffledHashJoin]].
    */
   object HashJoin extends Strategy with PredicateHelper {
 
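
The scaladoc above is the heart of the strategy: broadcast when one side's estimated size is at or below the threshold, otherwise shuffle and build on the smaller side. A hedged sketch of how a user would exercise it (it assumes a SQLContext named sqlContext with tables big and small registered, and that the SET key behind AUTO_BROADCASTJOIN_THRESHOLD is spark.sql.autoBroadcastJoinThreshold; both are illustrative, not taken from this diff):

  // Raise the broadcast threshold to 10 MB; a value <= 0 disables broadcasting.
  sqlContext.sql("SET spark.sql.autoBroadcastJoinThreshold=10485760")
  val joined = sqlContext.sql("SELECT * FROM big b JOIN small s ON b.key = s.key")
  // If small's estimated size is at or below the threshold, the physical plan
  // should contain join.BroadcastHashJoin with small as the build side;
  // otherwise the smaller side becomes the build side of join.ShuffledHashJoin.
  println(joined.queryExecution.executedPlan)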

@@ -66,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         left: LogicalPlan,
         right: LogicalPlan,
         condition: Option[Expression],
-        side: BuildSide) = {
-      val broadcastHashJoin = execution.BroadcastHashJoin(
+        side: join.BuildSide) = {
+      val broadcastHashJoin = execution.join.BroadcastHashJoin(
         leftKeys, rightKeys, side, planLater(left), planLater(right))
       condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
     }

@@ -76,27 +75,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
           if sqlContext.autoBroadcastJoinThreshold > 0 &&
              right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
+        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, join.BuildRight)
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
           if sqlContext.autoBroadcastJoinThreshold > 0 &&
             left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
+        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, join.BuildLeft)
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
         val buildSide =
           if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
-            BuildRight
+            join.BuildRight
           } else {
-            BuildLeft
+            join.BuildLeft
           }
-        val hashJoin =
-          execution.ShuffledHashJoin(
-            leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
+        val hashJoin = join.ShuffledHashJoin(
+          leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
-        execution.HashOuterJoin(
+        join.HashOuterJoin(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
       case _ => Nil

@@ -164,8 +162,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
         val buildSide =
-          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft
-        execution.BroadcastNestedLoopJoin(
+          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+            join.BuildRight
+          } else {
+            join.BuildLeft
+          }
+        join.BroadcastNestedLoopJoin(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
       case _ => Nil
     }

@@ -174,10 +176,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object CartesianProduct extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, _, None) =>
-        execution.CartesianProduct(planLater(left), planLater(right)) :: Nil
+        execution.join.CartesianProduct(planLater(left), planLater(right)) :: Nil
       case logical.Join(left, right, Inner, Some(condition)) =>
         execution.Filter(condition,
-          execution.CartesianProduct(planLater(left), planLater(right))) :: Nil
+          execution.join.CartesianProduct(planLater(left), planLater(right))) :: Nil
       case _ => Nil
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastHashJoin.scala

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.join

import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
 * :: DeveloperApi ::
 * Performs an inner hash join of two child relations. When the output RDD of this operator is
 * being constructed, a Spark job is asynchronously started to calculate the values for the
 * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
 * relation is not shuffled.
 */
@DeveloperApi
case class BroadcastHashJoin(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    buildSide: BuildSide,
    left: SparkPlan,
    right: SparkPlan)
  extends BinaryNode with HashJoin {

  override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning

  override def requiredChildDistribution =
    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

  // Collect and broadcast the build side eagerly, as soon as this operator is
  // constructed, so the work overlaps with whatever happens before execute().
  @transient
  private val broadcastFuture = future {
    sparkContext.broadcast(buildPlan.executeCollect())
  }

  override def execute() = {
    // Block here (for up to 5 minutes) until the broadcast of the build side is done.
    val broadcastRelation = Await.result(broadcastFuture, 5.minute)

    streamedPlan.execute().mapPartitions { streamedIter =>
      joinIterators(broadcastRelation.value.iterator, streamedIter)
    }
  }
}
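
The interesting wrinkle in BroadcastHashJoin is that the build side is collected and broadcast on a future that starts when the operator is constructed, and execute() only blocks when the result is actually needed. A self-contained sketch of that pattern with plain Scala futures (no Spark; every name here is illustrative):

  import scala.concurrent.{Await, Future}
  import scala.concurrent.duration._
  import scala.concurrent.ExecutionContext.Implicits.global

  object AsyncBuildSideDemo extends App {
    // Stand-in for buildPlan.executeCollect(): an expensive materialization.
    def collectBuildSide(): Seq[(Int, String)] = {
      Thread.sleep(100) // pretend this is a Spark job
      Seq(1 -> "a", 2 -> "b")
    }

    // Kicked off eagerly, like broadcastFuture above.
    val buildFuture: Future[Map[Int, String]] = Future(collectBuildSide().toMap)

    // The "streamed" side proceeds with other work in the meantime...
    val streamed = Seq(1 -> "x", 2 -> "y", 3 -> "z")

    // ...and we only block once the join actually needs the build side.
    val hashed = Await.result(buildFuture, 5.minutes)
    val joined = streamed.flatMap { case (k, v) => hashed.get(k).map(b => (k, v, b)) }
    println(joined) // List((1,x,a), (2,y,b))
  }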

sql/core/src/main/scala/org/apache/spark/sql/execution/join/BroadcastNestedLoopJoin.scala

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.join

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer

/**
 * :: DeveloperApi ::
 * Performs a nested-loop join of two child relations, broadcasting one side to every node.
 * The join condition, if any, is an arbitrary predicate, so this operator also covers joins
 * the hash-based operators cannot handle.
 */
@DeveloperApi
case class BroadcastNestedLoopJoin(
    left: SparkPlan,
    right: SparkPlan,
    buildSide: BuildSide,
    joinType: JoinType,
    condition: Option[Expression]) extends BinaryNode {
  // TODO: Override requiredChildDistribution.

  /** BuildRight means the right relation <=> the broadcast relation. */
  private val (streamed, broadcast) = buildSide match {
    case BuildRight => (left, right)
    case BuildLeft => (right, left)
  }

  override def outputPartitioning: Partitioning = streamed.outputPartitioning

  // Outer joins make the columns of the side(s) that may be null-padded nullable.
  override def output = {
    joinType match {
      case LeftOuter =>
        left.output ++ right.output.map(_.withNullability(true))
      case RightOuter =>
        left.output.map(_.withNullability(true)) ++ right.output
      case FullOuter =>
        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
      case _ =>
        left.output ++ right.output
    }
  }

  @transient private lazy val boundCondition =
    InterpretedPredicate(
      condition
        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
        .getOrElse(Literal(true)))

  override def execute() = {
    val broadcastedRelation =
      sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

    /** All rows that either match both-way, or rows from streamed joined with nulls. */
    val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
      val matchedRows = new CompactBuffer[Row]
      // TODO: Use Spark's BitSet.
      val includedBroadcastTuples =
        new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
      val joinedRow = new JoinedRow
      val leftNulls = new GenericMutableRow(left.output.size)
      val rightNulls = new GenericMutableRow(right.output.size)

      streamedIter.foreach { streamedRow =>
        var i = 0
        var streamRowMatched = false

        while (i < broadcastedRelation.value.size) {
          // TODO: One bitset per partition instead of per row.
          val broadcastedRow = broadcastedRelation.value(i)
          buildSide match {
            case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
              matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
              streamRowMatched = true
              includedBroadcastTuples += i
            case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
              matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
              streamRowMatched = true
              includedBroadcastTuples += i
            case _ =>
          }
          i += 1
        }

        // Streamed rows that matched nothing are null-padded for the relevant outer joins.
        (streamRowMatched, joinType, buildSide) match {
          case (false, LeftOuter | FullOuter, BuildRight) =>
            matchedRows += joinedRow(streamedRow, rightNulls).copy()
          case (false, RightOuter | FullOuter, BuildLeft) =>
            matchedRows += joinedRow(leftNulls, streamedRow).copy()
          case _ =>
        }
      }
      Iterator((matchedRows, includedBroadcastTuples))
    }

    // Union the per-partition bitsets to learn which broadcast rows matched anywhere.
    val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
    val allIncludedBroadcastTuples =
      if (includedBroadcastTuples.count == 0) {
        new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
      } else {
        includedBroadcastTuples.reduce(_ ++ _)
      }

    val leftNulls = new GenericMutableRow(left.output.size)
    val rightNulls = new GenericMutableRow(right.output.size)
    /** Rows from broadcasted joined with nulls. */
    val broadcastRowsWithNulls: Seq[Row] = {
      val buf: CompactBuffer[Row] = new CompactBuffer()
      var i = 0
      val rel = broadcastedRelation.value
      while (i < rel.length) {
        if (!allIncludedBroadcastTuples.contains(i)) {
          (joinType, buildSide) match {
            case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
            case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
            case _ =>
          }
        }
        i += 1
      }
      buf.toSeq
    }

    // TODO: Breaks lineage.
    sparkContext.union(
      matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
  }
}
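
The algorithm above runs in three phases: each streamed partition records its matches plus a bitset of broadcast-row indexes that matched; the bitsets are unioned across partitions; and broadcast rows whose bit never got set are emitted padded with nulls for the appropriate outer joins. A minimal sketch of the same idea on plain collections (illustrative names, no Spark):

  object NestedLoopFullOuterDemo extends App {
    val streamed  = Seq(1, 2, 3)          // streamed-side rows (keys only)
    val broadcast = IndexedSeq(2, 3, 4)   // "broadcast"-side rows
    val cond = (s: Int, b: Int) => s == b // an arbitrary predicate, as in condition

    val included = new scala.collection.mutable.BitSet(broadcast.size)
    // Phase 1: matches, plus streamed rows that matched nothing (null-padded).
    val streamSide: Seq[(Option[Int], Option[Int])] = streamed.flatMap { s =>
      val hits = broadcast.indices.filter(i => cond(s, broadcast(i)))
      hits.foreach(included += _)
      if (hits.nonEmpty) hits.map(i => (Some(s), Some(broadcast(i))))
      else Seq((Some(s), None)) // the LeftOuter/FullOuter padding
    }
    // Phases 2 and 3: broadcast rows that never matched any streamed row.
    val broadcastSide = broadcast.indices.collect {
      case i if !included(i) => (None, Some(broadcast(i)))
    }
    println(streamSide ++ broadcastSide)
    // List((Some(1),None), (Some(2),Some(2)), (Some(3),Some(3)), (None,Some(4)))
  }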

sql/core/src/main/scala/org/apache/spark/sql/execution/join/CartesianProduct.scala

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.join

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.JoinedRow
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
 * :: DeveloperApi ::
 */
@DeveloperApi
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
  override def output = left.output ++ right.output

  override def execute() = {
    // Defensive copies: operators may reuse mutable rows, and RDD.cartesian
    // buffers rows from both sides, so each row must be materialized first.
    val leftResults = left.execute().map(_.copy())
    val rightResults = right.execute().map(_.copy())

    leftResults.cartesian(rightResults).mapPartitions { iter =>
      // JoinedRow presents the two input rows as one without copying their fields.
      val joinedRow = new JoinedRow
      iter.map(r => joinedRow(r._1, r._2))
    }
  }
}
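
CartesianProduct leans on JoinedRow to avoid copying: rather than materializing a new row containing all fields of both inputs, JoinedRow is a view that answers field lookups from whichever side owns the index. A sketch of that idea with ordinary collections (PairView is made up for illustration; it is not the Spark class):

  object JoinedRowIdeaDemo extends App {
    // A cheap "two rows as one" view: no fields are copied.
    final class PairView[A](left: IndexedSeq[A], right: IndexedSeq[A]) extends IndexedSeq[A] {
      def length: Int = left.length + right.length
      def apply(i: Int): A = if (i < left.length) left(i) else right(i - left.length)
    }

    val row = new PairView(Vector(1, 2), Vector(3))
    println(row.toList) // List(1, 2, 3): one logical row backed by both inputs
  }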
