-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-13237] [SQL] generated broadcast outer join #11130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
52efe91
9525782
9a1f532
98cda0b
edbc284
da45df1
9b05c7c
1c0ee96
5724180
5744941
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext | |
| import org.apache.spark.broadcast.Broadcast | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} | ||
| import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} | ||
| import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} | ||
| import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} | ||
| import org.apache.spark.sql.execution.metric.SQLMetrics | ||
|
|
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer | |
| case class BroadcastHashJoin( | ||
| leftKeys: Seq[Expression], | ||
| rightKeys: Seq[Expression], | ||
| joinType: JoinType, | ||
| buildSide: BuildSide, | ||
| condition: Option[Expression], | ||
| left: SparkPlan, | ||
|
|
@@ -105,75 +107,144 @@ case class BroadcastHashJoin( | |
| val broadcastRelation = Await.result(broadcastFuture, timeout) | ||
|
|
||
| streamedPlan.execute().mapPartitions { streamedIter => | ||
| val hashedRelation = broadcastRelation.value | ||
| TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) | ||
| hashJoin(streamedIter, hashedRelation, numOutputRows) | ||
| val joinedRow = new JoinedRow() | ||
| val hashTable = broadcastRelation.value | ||
| TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) | ||
| val keyGenerator = streamSideKeyGenerator | ||
| val resultProj = createResultProjection | ||
|
|
||
| joinType match { | ||
| case Inner => | ||
| hashJoin(streamedIter, hashTable, numOutputRows) | ||
|
|
||
| case LeftOuter => | ||
| streamedIter.flatMap { currentRow => | ||
| val rowKey = keyGenerator(currentRow) | ||
| joinedRow.withLeft(currentRow) | ||
| leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) | ||
| } | ||
|
|
||
| case RightOuter => | ||
| streamedIter.flatMap { currentRow => | ||
| val rowKey = keyGenerator(currentRow) | ||
| joinedRow.withRight(currentRow) | ||
| rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) | ||
| } | ||
|
|
||
| case x => | ||
| throw new IllegalArgumentException( | ||
| s"BroadcastHashJoin should not take $x as the JoinType") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private var broadcastRelation: Broadcast[HashedRelation] = _ | ||
| // the term for hash relation | ||
| private var relationTerm: String = _ | ||
|
|
||
| override def upstream(): RDD[InternalRow] = { | ||
| streamedPlan.asInstanceOf[CodegenSupport].upstream() | ||
| } | ||
|
|
||
| override def doProduce(ctx: CodegenContext): String = { | ||
| streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) | ||
| } | ||
|
|
||
| override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
| if (joinType == Inner) { | ||
| codegenInner(ctx, input) | ||
| } else { | ||
| // LeftOuter and RightOuter | ||
| codegenOuter(ctx, input) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns a tuple of Broadcast of HashedRelation and the variable name for it. | ||
| */ | ||
| private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { | ||
| // create a name for HashedRelation | ||
| broadcastRelation = Await.result(broadcastFuture, timeout) | ||
| val broadcastRelation = Await.result(broadcastFuture, timeout) | ||
| val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) | ||
| relationTerm = ctx.freshName("relation") | ||
| val relationTerm = ctx.freshName("relation") | ||
| val clsName = broadcastRelation.value.getClass.getName | ||
| ctx.addMutableState(clsName, relationTerm, | ||
| s""" | ||
| | $relationTerm = ($clsName) $broadcast.value(); | ||
| | incPeakExecutionMemory($relationTerm.getMemorySize()); | ||
| """.stripMargin) | ||
|
|
||
| s""" | ||
| | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} | ||
| """.stripMargin | ||
| (broadcastRelation, relationTerm) | ||
| } | ||
|
|
||
| override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
| // generate the key as UnsafeRow or Long | ||
| /** | ||
| * Returns the code for generating the join key for the stream side, and an expression that | ||
| * tells whether the key has any null in it. | ||
| */ | ||
| private def genStreamSideJoinKey( | ||
| ctx: CodegenContext, | ||
| input: Seq[ExprCode]): (ExprCode, String) = { | ||
| ctx.currentVars = input | ||
| val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { | ||
| if (canJoinKeyFitWithinLong) { | ||
| // generate the join key as Long | ||
| val expr = rewriteKeyExpr(streamedKeys).head | ||
| val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) | ||
| (ev, ev.isNull) | ||
| } else { | ||
| // generate the join key as UnsafeRow | ||
| val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) | ||
| val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) | ||
| (ev, s"${ev.value}.anyNull()") | ||
| } | ||
| } | ||
|
|
||
| // find the matches from HashedRelation | ||
| val matched = ctx.freshName("matched") | ||
|
|
||
| // create variables for output | ||
| /** | ||
| * Generates the code for the variables of the build side. | ||
| */ | ||
| private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { | ||
| ctx.currentVars = null | ||
| ctx.INPUT_ROW = matched | ||
| val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => | ||
| BoundReference(i, a.dataType, a.nullable).gen(ctx) | ||
| buildPlan.output.zipWithIndex.map { case (a, i) => | ||
| val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) | ||
| if (joinType == Inner) { | ||
| ev | ||
| } else { | ||
| // the variables are needed even if there are no matched rows | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm really confused by this. What is this doing?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, never mind. I think I get it now. |
||
| val isNull = ctx.freshName("isNull") | ||
| val value = ctx.freshName("value") | ||
| val code = s""" | ||
| |boolean $isNull = true; | ||
| |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; | ||
| |if ($matched != null) { | ||
| | ${ev.code} | ||
| | $isNull = ${ev.isNull}; | ||
| | $value = ${ev.value}; | ||
| |} | ||
| """.stripMargin | ||
| ExprCode(code, isNull, value) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Generates the code for Inner join. | ||
| */ | ||
| private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
| val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) | ||
| val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) | ||
| val matched = ctx.freshName("matched") | ||
| val buildVars = genBuildSideVars(ctx, matched) | ||
| val resultVars = buildSide match { | ||
| case BuildLeft => buildColumns ++ input | ||
| case BuildRight => input ++ buildColumns | ||
| case BuildLeft => buildVars ++ input | ||
| case BuildRight => input ++ buildVars | ||
| } | ||
|
|
||
| val numOutput = metricTerm(ctx, "numOutputRows") | ||
|
|
||
| val outputCode = if (condition.isDefined) { | ||
| // filter the output via condition | ||
| ctx.currentVars = resultVars | ||
| val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) | ||
| s""" | ||
| | ${ev.code} | ||
| | if (!${ev.isNull} && ${ev.value}) { | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| | } | ||
| |${ev.code} | ||
| |if (!${ev.isNull} && ${ev.value}) { | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| s""" | ||
|
|
@@ -184,36 +255,110 @@ case class BroadcastHashJoin( | |
|
|
||
| if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { | ||
| s""" | ||
| | // generate join key | ||
| | ${keyVal.code} | ||
| | // find matches from HashedRelation | ||
| | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); | ||
| | if ($matched != null) { | ||
| | ${buildColumns.map(_.code).mkString("\n")} | ||
| | $outputCode | ||
| | } | ||
| """.stripMargin | ||
| |// generate join key for stream side | ||
| |${keyEv.code} | ||
| |// find matches from HashedRelation | ||
| |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); | ||
| |if ($matched != null) { | ||
| | ${buildVars.map(_.code).mkString("\n")} | ||
| | $outputCode | ||
| |} | ||
| """.stripMargin | ||
|
|
||
| } else { | ||
| val matches = ctx.freshName("matches") | ||
| val bufferType = classOf[CompactBuffer[UnsafeRow]].getName | ||
| val i = ctx.freshName("i") | ||
| val size = ctx.freshName("size") | ||
| s""" | ||
| |// generate join key for stream side | ||
| |${keyEv.code} | ||
| |// find matches from HashRelation | ||
| |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); | ||
| |if ($matches != null) { | ||
| | int $size = $matches.size(); | ||
| | for (int $i = 0; $i < $size; $i++) { | ||
| | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); | ||
| | ${buildVars.map(_.code).mkString("\n")} | ||
| | $outputCode | ||
| | } | ||
| |} | ||
| """.stripMargin | ||
| } | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Generates the code for left or right outer join. | ||
| */ | ||
| private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a lot of code duplicated from codegenInner. Can we merge them?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I already tried hard to reduce the duplicated code; it will be harder to understand if we have more small fragments. |
||
| val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) | ||
| val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) | ||
| val matched = ctx.freshName("matched") | ||
| val buildVars = genBuildSideVars(ctx, matched) | ||
| val resultVars = buildSide match { | ||
| case BuildLeft => buildVars ++ input | ||
| case BuildRight => input ++ buildVars | ||
| } | ||
| val numOutput = metricTerm(ctx, "numOutputRows") | ||
|
|
||
| // filter the output via condition | ||
| val conditionPassed = ctx.freshName("conditionPassed") | ||
| val checkCondition = if (condition.isDefined) { | ||
| ctx.currentVars = resultVars | ||
| val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) | ||
| s""" | ||
| |boolean $conditionPassed = true; | ||
| |if ($matched != null) { | ||
| | ${ev.code} | ||
| | $conditionPassed = !${ev.isNull} && ${ev.value}; | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| s"final boolean $conditionPassed = true;" | ||
| } | ||
|
|
||
| if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { | ||
| s""" | ||
| |// generate join key for stream side | ||
| |${keyEv.code} | ||
| |// find matches from HashedRelation | ||
| |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); | ||
| |${buildVars.map(_.code).mkString("\n")} | ||
| |${checkCondition.trim} | ||
| |if (!$conditionPassed) { | ||
| | // reset to null | ||
| | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")} | ||
| |} | ||
| |$numOutput.add(1); | ||
| |${consume(ctx, resultVars)} | ||
| """.stripMargin | ||
|
|
||
| } else { | ||
| val matches = ctx.freshName("matches") | ||
| val bufferType = classOf[CompactBuffer[UnsafeRow]].getName | ||
| val i = ctx.freshName("i") | ||
| val size = ctx.freshName("size") | ||
| val found = ctx.freshName("found") | ||
| s""" | ||
| | // generate join key | ||
| | ${keyVal.code} | ||
| | // find matches from HashRelation | ||
| | $bufferType $matches = ${anyNull} ? null : | ||
| | ($bufferType) $relationTerm.get(${keyVal.value}); | ||
| | if ($matches != null) { | ||
| | int $size = $matches.size(); | ||
| | for (int $i = 0; $i < $size; $i++) { | ||
| | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); | ||
| | ${buildColumns.map(_.code).mkString("\n")} | ||
| | $outputCode | ||
| | } | ||
| | } | ||
| """.stripMargin | ||
| |// generate join key for stream side | ||
| |${keyEv.code} | ||
| |// find matches from HashRelation | ||
| |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); | ||
| |int $size = $matches != null ? $matches.size() : 0; | ||
| |boolean $found = false; | ||
| |// the last iteration of this loop is to emit an empty row if there is no matched rows. | ||
| |for (int $i = 0; $i <= $size; $i++) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is clever, but i think you need to document it (i.e. you are adding an extra iteration at the end of the loop to handle null) |
||
| | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; | ||
| | ${buildVars.map(_.code).mkString("\n")} | ||
| | ${checkCondition.trim} | ||
| | if ($conditionPassed && ($i < $size || !$found)) { | ||
| | $found = true; | ||
| | $numOutput.add(1); | ||
| | ${consume(ctx, resultVars)} | ||
| | } | ||
| |} | ||
| """.stripMargin | ||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
document what the fields in the return tuple mean