Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// --- Inner joins --------------------------------------------------------------------------

case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
joins.BroadcastHashJoin(
leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))

case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
joins.BroadcastHashJoin(
leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))

case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
Expand All @@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

case ExtractEquiJoinKeys(
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))

case ExtractEquiJoinKeys(
RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue

/**
* An interface for those physical operators that support codegen.
Expand All @@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
case _: BroadcastHashJoin => "bhj"
case _: BroadcastHashJoin => "join"
case _ => nodeName.toLowerCase
}

Expand Down Expand Up @@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
Expand All @@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
Expand Down Expand Up @@ -105,75 +107,144 @@ case class BroadcastHashJoin(
val broadcastRelation = Await.result(broadcastFuture, timeout)

streamedPlan.execute().mapPartitions { streamedIter =>
val hashedRelation = broadcastRelation.value
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
hashJoin(streamedIter, hashedRelation, numOutputRows)
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
val keyGenerator = streamSideKeyGenerator
val resultProj = createResultProjection

joinType match {
case Inner =>
hashJoin(streamedIter, hashTable, numOutputRows)

case LeftOuter =>
streamedIter.flatMap { currentRow =>
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
}

case RightOuter =>
streamedIter.flatMap { currentRow =>
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
}

case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
}
}
}

private var broadcastRelation: Broadcast[HashedRelation] = _
// the term for hash relation
private var relationTerm: String = _

/**
 * Returns the RDD produced by calling `upstream()` on the streamed child, which is cast to
 * [[CodegenSupport]] first.
 */
override def upstream(): RDD[InternalRow] = {
  val streamedCodegen = streamedPlan.asInstanceOf[CodegenSupport]
  streamedCodegen.upstream()
}

/**
 * Asks the streamed child (cast to [[CodegenSupport]]) to produce its code, passing this
 * operator as the parent, and returns the resulting source string.
 */
override def doProduce(ctx: CodegenContext): String = {
  val streamedCodegen = streamedPlan.asInstanceOf[CodegenSupport]
  streamedCodegen.produce(ctx, this)
}

/**
 * Generates the join code for the given input variables, dispatching on `joinType`.
 *
 * Inner joins are handled by `codegenInner`; LeftOuter and RightOuter by `codegenOuter`.
 * Any other join type is rejected with an IllegalArgumentException carrying the same message
 * the row-based execution path of this operator uses, instead of being silently compiled as
 * an outer join.
 *
 * @param ctx   the codegen context shared by the operators being compiled together
 * @param input the generated variables handed to this operator by its child
 * @return the generated source for this join
 * @throws IllegalArgumentException if `joinType` is not Inner, LeftOuter or RightOuter
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  joinType match {
    case Inner =>
      codegenInner(ctx, input)
    case LeftOuter | RightOuter =>
      codegenOuter(ctx, input)
    case x =>
      // Fail fast: the outer-join code path is only correct for LeftOuter / RightOuter.
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }
}

/**
* Returns a tuple of Broadcast of HashedRelation and the variable name for it.
*/
private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document what the fields in the return tuple mean

// create a name for HashedRelation
broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
relationTerm = ctx.freshName("relation")
val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
| incPeakExecutionMemory($relationTerm.getMemorySize());
""".stripMargin)

s"""
| ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
""".stripMargin
(broadcastRelation, relationTerm)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// generate the key as UnsafeRow or Long
/**
* Returns the code that generates the join key for the stream side, and an expression that
* tells whether the key contains any null.
*/
private def genStreamSideJoinKey(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
if (canJoinKeyFitWithinLong) {
// generate the join key as Long
val expr = rewriteKeyExpr(streamedKeys).head
val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
(ev, s"${ev.value}.anyNull()")
}
}

// find the matches from HashedRelation
val matched = ctx.freshName("matched")

// create variables for output
/**
* Generates the code for the variables of the build side.
*/
private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
ctx.currentVars = null
ctx.INPUT_ROW = matched
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).gen(ctx)
buildPlan.output.zipWithIndex.map { case (a, i) =>
val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
if (joinType == Inner) {
ev
} else {
// the variables are needed even if there are no matched rows
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really confused by this. What is this doing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok never mind i think i get it now.

val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val code = s"""
|boolean $isNull = true;
|${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
|if ($matched != null) {
| ${ev.code}
| $isNull = ${ev.isNull};
| $value = ${ev.value};
|}
""".stripMargin
ExprCode(code, isNull, value)
}
}
}

/**
* Generates the code for Inner join.
*/
private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
val resultVars = buildSide match {
case BuildLeft => buildColumns ++ input
case BuildRight => input ++ buildColumns
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}

val numOutput = metricTerm(ctx, "numOutputRows")

val outputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
s"""
| ${ev.code}
| if (!${ev.isNull} && ${ev.value}) {
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
|${ev.code}
|if (!${ev.isNull} && ${ev.value}) {
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
|}
""".stripMargin
} else {
s"""
Expand All @@ -184,36 +255,110 @@ case class BroadcastHashJoin(

if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
| // generate join key
| ${keyVal.code}
| // find matches from HashedRelation
| UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
| if ($matched != null) {
| ${buildColumns.map(_.code).mkString("\n")}
| $outputCode
| }
""".stripMargin
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
|if ($matched != null) {
| ${buildVars.map(_.code).mkString("\n")}
| $outputCode
|}
""".stripMargin

} else {
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
|if ($matches != null) {
| int $size = $matches.size();
| for (int $i = 0; $i < $size; $i++) {
| UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
| ${buildVars.map(_.code).mkString("\n")}
| $outputCode
| }
|}
""".stripMargin
}
}


/**
* Generates the code for left or right outer join.
*/
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a lot of code duplicated from codegenInner. Can we merge them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already tried hard to reduce the duplicated code; it will be harder to understand if we split it into more small fragments.

val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
val resultVars = buildSide match {
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
val numOutput = metricTerm(ctx, "numOutputRows")

// filter the output via condition
val conditionPassed = ctx.freshName("conditionPassed")
val checkCondition = if (condition.isDefined) {
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
s"""
|boolean $conditionPassed = true;
|if ($matched != null) {
| ${ev.code}
| $conditionPassed = !${ev.isNull} && ${ev.value};
|}
""".stripMargin
} else {
s"final boolean $conditionPassed = true;"
}

if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
|${buildVars.map(_.code).mkString("\n")}
|${checkCondition.trim}
|if (!$conditionPassed) {
| // reset to null
| ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
|}
|$numOutput.add(1);
|${consume(ctx, resultVars)}
""".stripMargin

} else {
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
val found = ctx.freshName("found")
s"""
| // generate join key
| ${keyVal.code}
| // find matches from HashRelation
| $bufferType $matches = ${anyNull} ? null :
| ($bufferType) $relationTerm.get(${keyVal.value});
| if ($matches != null) {
| int $size = $matches.size();
| for (int $i = 0; $i < $size; $i++) {
| UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
| ${buildColumns.map(_.code).mkString("\n")}
| $outputCode
| }
| }
""".stripMargin
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
|int $size = $matches != null ? $matches.size() : 0;
|boolean $found = false;
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
|for (int $i = 0; $i <= $size; $i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is clever, but I think you need to document it (i.e. you are adding an extra iteration at the end of the loop to handle null).

| UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
| ${buildVars.map(_.code).mkString("\n")}
| ${checkCondition.trim}
| if ($conditionPassed && ($i < $size || !$found)) {
| $found = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
|}
""".stripMargin
}
}
}
Expand Down
Loading