Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
joins.BuildLeft
}
val hashJoin = joins.ShuffledHashJoin(
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
leftKeys, rightKeys, buildSide, Inner, condition, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) =>
joins.ShuffledHashJoin(
leftKeys, rightKeys, joins.BuildRight, LeftOuter,
condition, planLater(left), planLater(right)) :: Nil

case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) =>
joins.ShuffledHashJoin(
leftKeys, rightKeys, joins.BuildLeft, RightOuter,
condition, planLater(left), planLater(right)) :: Nil

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.HashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.{Inner, FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

import org.apache.spark.util.collection.CompactBuffer
/**
* :: DeveloperApi ::
* Performs an inner, left outer or right outer hash join of two child relations by first
* shuffling the data using the join
Expand All @@ -32,19 +33,115 @@ case class ShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {

override def outputPartitioning: Partitioning = left.outputPartitioning
// The result keeps the partitioning of the side whose rows are all preserved: the left child for
// inner and left outer joins, the right child for right outer joins.
override def outputPartitioning: Partitioning = {
  joinType match {
    case Inner | LeftOuter =>
      left.outputPartitioning
    case RightOuter =>
      right.outputPartitioning
    case x =>
      throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
  }
}

// Each child must be clustered by its own join keys (left child first, then right child).
override def requiredChildDistribution = {
  List(ClusteredDistribution(leftKeys), ClusteredDistribution(rightKeys))
}

// Output is the left attributes followed by the right attributes. The side that can be padded
// with nulls by an outer join has its attributes marked nullable.
override def output = {
  val (leftAttributes, rightAttributes) = joinType match {
    case Inner => (left.output, right.output)
    case LeftOuter => (left.output, right.output.map(_.withNullability(true)))
    case RightOuter => (left.output.map(_.withNullability(true)), right.output)
    case x => throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
  }
  leftAttributes ++ rightAttributes
}

// Placeholder row standing in for the non-preserved side of an outer join: as wide as the right
// child's output for LeftOuter and as wide as the left child's output for RightOuter. It is null
// for every other join type.
private[this] lazy val nullRow = {
  joinType match {
    case RightOuter =>
      new GenericRow(left.output.length)
    case LeftOuter =>
      new GenericRow(right.output.length)
    case _ =>
      null
  }
}

// Predicate over a joined row (left attributes followed by right attributes). When no extra join
// condition was supplied, every row is accepted.
private[this] lazy val boundCondition = condition match {
  case Some(predicate) => newPredicate(predicate, left.output ++ right.output)
  case None => (row: Row) => true
}

/**
 * Streams the rows of a left or right outer join against a hash relation of the build side.
 *
 * A streamed row is emitted once for every build-side match that also satisfies
 * `boundCondition`. A streamed row is emitted exactly once, padded with `nullRow`, when its join
 * keys contain a null, when the hash relation has no entry for its keys, or when none of its
 * matches satisfies the condition.
 *
 * Rows are assembled through a single reused `JoinedRow2`, so callers should presumably consume
 * (or copy) a row before advancing the iterator again.
 *
 * @param streamIter rows of the streamed side; they are placed on the left of the joined row for
 *                   `LeftOuter` and on the right for `RightOuter`
 * @param hashedRelation build-side rows, looked up by the join keys of each streamed row
 * @return an iterator over the joined rows; `hasNext` may be called repeatedly without losing
 *         rows, and `next()` throws `NoSuchElementException` once the input is exhausted
 */
private def outerJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = {
  new Iterator[Row] {
    private[this] var currentStreamedRow: Row = _
    private[this] var currentHashMatches: CompactBuffer[Row] = _
    private[this] var currentMatchPosition: Int = 0

    // True while the current streamed row has not produced any output row yet.
    private[this] var needsNullRow: Boolean = false

    // Row handed out by the next call to next(); only meaningful while rowReady is set.
    private[this] var pendingRow: Row = _
    private[this] var rowReady: Boolean = false
    private[this] var exhausted: Boolean = false

    // Mutable per row objects.
    private[this] val joinRow = new JoinedRow2

    private[this] val joinKeys = streamSideKeyGenerator()

    // Looking ahead is buffered in pendingRow, so repeated calls do not skip streamed rows.
    override final def hasNext: Boolean = rowReady || findNext()

    override final def next(): Row = {
      if (!hasNext) {
        throw new NoSuchElementException("next on empty iterator")
      }
      rowReady = false
      pendingRow
    }

    // Combines the current streamed row with a build-side (or null) row in output column order.
    private[this] def joinWith(buildRow: Row): Row = joinType match {
      case LeftOuter => joinRow(currentStreamedRow, buildRow)
      case RightOuter => joinRow(buildRow, currentStreamedRow)
      case x => throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
    }

    // Advances until an output row is buffered or the streamed side is exhausted.
    private[this] def findNext(): Boolean = {
      while (!rowReady && !exhausted) {
        if (currentHashMatches != null && currentMatchPosition < currentHashMatches.size) {
          // Try the next hash match; matches failing the join condition are skipped rather
          // than turned into null-padded rows.
          val candidate = joinWith(currentHashMatches(currentMatchPosition))
          currentMatchPosition += 1
          if (boundCondition(candidate)) {
            needsNullRow = false
            pendingRow = candidate
            rowReady = true
          }
        } else if (needsNullRow) {
          // Nothing qualified for this streamed row: emit it once with the other side nulled.
          needsNullRow = false
          pendingRow = joinWith(nullRow)
          rowReady = true
        } else if (streamIter.hasNext) {
          currentStreamedRow = streamIter.next()
          currentMatchPosition = 0
          needsNullRow = true
          // Null join keys never match, so the hash relation is not consulted for them.
          currentHashMatches =
            if (joinKeys(currentStreamedRow).anyNull) null
            else hashedRelation.get(joinKeys.currentValue)
        } else {
          exhausted = true
        }
      }
      rowReady
    }
  }
}

/**
 * Joins each build-side partition with the corresponding streamed-side partition.
 *
 * For every pair of partitions a hash relation is built from the build side and then probed
 * with the streamed rows: `hashJoin` handles `Inner`, `outerJoin` handles `LeftOuter` and
 * `RightOuter`. Any other join type makes the per-partition function throw an `Exception`.
 */
override def execute() = {
  buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
    val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
    // Dispatch exactly once per partition; no join iterator is created and then discarded.
    joinType match {
      case Inner => hashJoin(streamIter, hashed)
      case LeftOuter | RightOuter => outerJoin(streamIter, hashed)
      case x => throw new Exception(s"ShuffledHashJoin should not take $x as the JoinType")
    }
  }
}
}
6 changes: 3 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[HashOuterJoin]),
classOf[ShuffledHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[HashOuterJoin]),
classOf[ShuffledHashJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
// TODO add BroadcastNestedLoopJoin
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
Expand Down