diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 558d990e8c4b..59a42d893192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -50,6 +50,7 @@ trait CodegenSupport extends SparkPlan {
   private def variablePrefix: String = this match {
     case _: HashAggregateExec => "agg"
     case _: BroadcastHashJoinExec => "bhj"
+    case _: ShuffledHashJoinExec => "shj"
     case _: SortMergeJoinExec => "smj"
     case _: RDDScanExec => "rdd"
     case _: DataSourceScanExec => "scan"
@@ -903,6 +904,10 @@ case class CollapseCodegenStages(
       // The children of SortMergeJoin should do codegen separately.
       j.withNewChildren(j.children.map(
        child => InputAdapter(insertWholeStageCodegen(child))))
+    case j: ShuffledHashJoinExec =>
+      // The children of ShuffledHashJoin should do codegen separately.
+      j.withNewChildren(j.children.map(
+        child => InputAdapter(insertWholeStageCodegen(child))))
     case p => p.withNewChildren(p.children.map(insertInputAdapter))
   }
 }
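Reviewer note: as with SortMergeJoinExec, the join's children are wrapped in InputAdapter, so each side becomes its own codegen stage while the join body itself is compiled with the new "shj" variable prefix. A minimal sketch of how this surfaces to users, assuming a local session; it mirrors the confs used by the new WholeStageCodegenSuite test further down, and the object/column names are illustrative, not from the patch:

```scala
import org.apache.spark.sql.SparkSession

object ShjCodegenDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("shj-demo").getOrCreate()
    import spark.implicits._

    // Steer the planner toward a shuffled hash join: a tiny broadcast threshold so the
    // build side cannot be broadcast, but small enough for a per-partition hash map.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "30")
    spark.conf.set("spark.sql.shuffle.partitions", "2")
    spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

    val df1 = spark.range(5).select($"id".as("k1"))
    val df2 = spark.range(15).select($"id".as("k2"))

    // With this patch the join body is code-generated (look for "shj_" variables in the
    // dump), while each child sits behind an InputAdapter in its own codegen stage.
    df1.join(df2, $"k1" === $"k2").explain("codegen")

    spark.stop()
  }
}
```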
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 2a283013acee..e4935c8c7222 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -25,13 +25,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{BooleanType, LongType}
 
 /**
  * Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -197,23 +195,6 @@ case class BroadcastHashJoinExec(
   override def needCopyResult: Boolean =
     streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput
 
-  override def doProduce(ctx: CodegenContext): String = {
-    streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
-  }
-
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    joinType match {
-      case _: InnerLike => codegenInner(ctx, input)
-      case LeftOuter | RightOuter => codegenOuter(ctx, input)
-      case LeftSemi => codegenSemi(ctx, input)
-      case LeftAnti => codegenAnti(ctx, input)
-      case j: ExistenceJoin => codegenExistence(ctx, input)
-      case x =>
-        throw new IllegalArgumentException(
-          s"BroadcastHashJoin should not take $x as the JoinType")
-    }
-  }
-
   /**
    * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
    */
@@ -232,411 +213,55 @@ case class BroadcastHashJoinExec(
     (broadcastRelation, relationTerm)
   }
 
-  /**
-   * Returns the code for generating join key for stream side, and expression of whether the key
-   * has any null in it or not.
-   */
-  private def genStreamSideJoinKey(
-      ctx: CodegenContext,
-      input: Seq[ExprCode]): (ExprCode, String) = {
-    ctx.currentVars = input
-    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
-      // generate the join key as Long
-      val ev = streamedBoundKeys.head.genCode(ctx)
-      (ev, ev.isNull)
-    } else {
-      // generate the join key as UnsafeRow
-      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
-      (ev, s"${ev.value}.anyNull()")
-    }
-  }
-
-  /**
-   * Generates the code for variable of build side.
-   */
-  private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
-    ctx.currentVars = null
-    ctx.INPUT_ROW = matched
-    buildPlan.output.zipWithIndex.map { case (a, i) =>
-      val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
-      if (joinType.isInstanceOf[InnerLike]) {
-        ev
-      } else {
-        // the variables are needed even there is no matched rows
-        val isNull = ctx.freshName("isNull")
-        val value = ctx.freshName("value")
-        val javaType = CodeGenerator.javaType(a.dataType)
-        val code = code"""
-          |boolean $isNull = true;
-          |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
-          |if ($matched != null) {
-          |  ${ev.code}
-          |  $isNull = ${ev.isNull};
-          |  $value = ${ev.value};
-          |}
-         """.stripMargin
-        ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType))
-      }
-    }
-  }
-
-  /**
-   * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
-   * and Left Anti joins.
-   */
-  private def getJoinCondition(
-      ctx: CodegenContext,
-      input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
-    val matched = ctx.freshName("matched")
-    val buildVars = genBuildSideVars(ctx, matched)
-    val checkCondition = if (condition.isDefined) {
-      val expr = condition.get
-      // evaluate the variables from build side that used by condition
-      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
-      // filter the output via condition
-      ctx.currentVars = input ++ buildVars
-      val ev =
-        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
-      val skipRow = s"${ev.isNull} || !${ev.value}"
-      s"""
-         |$eval
-         |${ev.code}
-         |if (!($skipRow))
-       """.stripMargin
-    } else {
-      ""
-    }
-    (matched, checkCondition, buildVars)
-  }
-
-  /**
-   * Generates the code for Inner join.
-   */
-  private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
-    val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
-    val numOutput = metricTerm(ctx, "numOutputRows")
-
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
-    if (broadcastRelation.value.keyIsUnique) {
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashedRelation
-         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |if ($matched != null) {
-         |  $checkCondition {
-         |    $numOutput.add(1);
-         |    ${consume(ctx, resultVars)}
-         |  }
-         |}
-       """.stripMargin
-
-    } else {
-      val matches = ctx.freshName("matches")
-      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashRelation
-         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
-         |if ($matches != null) {
-         |  while ($matches.hasNext()) {
-         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
-         |    $checkCondition {
-         |      $numOutput.add(1);
-         |      ${consume(ctx, resultVars)}
-         |    }
-         |  }
-         |}
-       """.stripMargin
-    }
-  }
-
-  /**
-   * Generates the code for left or right outer join.
-   */
-  private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
-    val matched = ctx.freshName("matched")
-    val buildVars = genBuildSideVars(ctx, matched)
-    val numOutput = metricTerm(ctx, "numOutputRows")
-
-    // filter the output via condition
-    val conditionPassed = ctx.freshName("conditionPassed")
-    val checkCondition = if (condition.isDefined) {
-      val expr = condition.get
-      // evaluate the variables from build side that used by condition
-      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
-      ctx.currentVars = input ++ buildVars
-      val ev =
-        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
-      s"""
-         |boolean $conditionPassed = true;
-         |${eval.trim}
-         |if ($matched != null) {
-         |  ${ev.code}
-         |  $conditionPassed = !${ev.isNull} && ${ev.value};
-         |}
-       """.stripMargin
-    } else {
-      s"final boolean $conditionPassed = true;"
-    }
-
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
-    if (broadcastRelation.value.keyIsUnique) {
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashedRelation
-         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |${checkCondition.trim}
-         |if (!$conditionPassed) {
-         |  $matched = null;
-         |  // reset the variables those are already evaluated.
-         |  ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} = true;").mkString("\n")}
-         |}
-         |$numOutput.add(1);
-         |${consume(ctx, resultVars)}
-       """.stripMargin
-
-    } else {
-      val matches = ctx.freshName("matches")
-      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
-      val found = ctx.freshName("found")
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashRelation
-         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
-         |boolean $found = false;
-         |// the last iteration of this loop is to emit an empty row if there is no matched rows.
-         |while ($matches != null && $matches.hasNext() || !$found) {
-         |  UnsafeRow $matched = $matches != null && $matches.hasNext() ?
-         |    (UnsafeRow) $matches.next() : null;
-         |  ${checkCondition.trim}
-         |  if ($conditionPassed) {
-         |    $found = true;
-         |    $numOutput.add(1);
-         |    ${consume(ctx, resultVars)}
-         |  }
-         |}
-       """.stripMargin
-    }
-  }
-
-  /**
-   * Generates the code for left semi join.
-   */
-  private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = {
     val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
-    val (matched, checkCondition, _) = getJoinCondition(ctx, input)
-    val numOutput = metricTerm(ctx, "numOutputRows")
-    if (broadcastRelation.value.keyIsUnique) {
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashedRelation
-         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |if ($matched != null) {
-         |  $checkCondition {
-         |    $numOutput.add(1);
-         |    ${consume(ctx, input)}
-         |  }
-         |}
-       """.stripMargin
-    } else {
-      val matches = ctx.freshName("matches")
-      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
-      val found = ctx.freshName("found")
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashRelation
-         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
-         |if ($matches != null) {
-         |  boolean $found = false;
-         |  while (!$found && $matches.hasNext()) {
-         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
-         |    $checkCondition {
-         |      $found = true;
-         |    }
-         |  }
-         |  if ($found) {
-         |    $numOutput.add(1);
-         |    ${consume(ctx, input)}
-         |  }
-         |}
-       """.stripMargin
-    }
+    (relationTerm, broadcastRelation.value.keyIsUnique)
   }
 
   /**
    * Generates the code for anti join.
+   * Handles NULL-aware anti join (NAAJ) separately here.
    */
-  private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-    val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique
-    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
-    val (matched, checkCondition, _) = getJoinCondition(ctx, input)
-    val numOutput = metricTerm(ctx, "numOutputRows")
-
+  protected override def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     if (isNullAwareAntiJoin) {
+      val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+      val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+      val (matched, _, _) = getJoinCondition(ctx, input)
+      val numOutput = metricTerm(ctx, "numOutputRows")
+
       if (broadcastRelation.value == EmptyHashedRelation) {
-        return s"""
-          |// If the right side is empty, NAAJ simply returns the left side.
-          |$numOutput.add(1);
-          |${consume(ctx, input)}
-        """.stripMargin
+        s"""
+          |// If the right side is empty, NAAJ simply returns the left side.
+          |$numOutput.add(1);
+          |${consume(ctx, input)}
+        """.stripMargin
       } else if (broadcastRelation.value == EmptyHashedRelationWithAllNullKeys) {
-        return s"""
-          |// If the right side contains any all-null key, NAAJ simply returns Nothing.
-        """.stripMargin
+        s"""
+          |// If the right side contains any all-null key, NAAJ simply returns Nothing.
+ """.stripMargin } else { val found = ctx.freshName("found") - return s""" - |boolean $found = false; - |// generate join key for stream side - |${keyEv.code} - |if ($anyNull) { - | $found = true; - |} else { - | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - | if ($matched != null) { - | $found = true; - | } - |} - | - |if (!$found) { - | $numOutput.add(1); - | ${consume(ctx, input)} - |} - """.stripMargin + s""" + |boolean $found = false; + |// generate join key for stream side + |${keyEv.code} + |if ($anyNull) { + | $found = true; + |} else { + | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + | if ($matched != null) { + | $found = true; + | } + |} + | + |if (!$found) { + | $numOutput.add(1); + | ${consume(ctx, input)} + |} + """.stripMargin } - } - - if (uniqueKeyCodePath) { - val found = ctx.freshName("found") - s""" - |boolean $found = false; - |// generate join key for stream side - |${keyEv.code} - |// Check if the key has nulls. - |if (!($anyNull)) { - | // Check if the HashedRelation exists. - | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - | if ($matched != null) { - | // Evaluate the condition. - | $checkCondition { - | $found = true; - | } - | } - |} - |if (!$found) { - | $numOutput.add(1); - | ${consume(ctx, input)} - |} - """.stripMargin - } else { - val matches = ctx.freshName("matches") - val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val found = ctx.freshName("found") - s""" - |boolean $found = false; - |// generate join key for stream side - |${keyEv.code} - |// Check if the key has nulls. - |if (!($anyNull)) { - | // Check if the HashedRelation exists. - | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); - | if ($matches != null) { - | // Evaluate the condition. - | while (!$found && $matches.hasNext()) { - | UnsafeRow $matched = (UnsafeRow) $matches.next(); - | $checkCondition { - | $found = true; - | } - | } - | } - |} - |if (!$found) { - | $numOutput.add(1); - | ${consume(ctx, input)} - |} - """.stripMargin - } - } - - /** - * Generates the code for existence join. - */ - private def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val numOutput = metricTerm(ctx, "numOutputRows") - val existsVar = ctx.freshName("exists") - - val matched = ctx.freshName("matched") - val buildVars = genBuildSideVars(ctx, matched) - val checkCondition = if (condition.isDefined) { - val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) - // filter the output via condition - ctx.currentVars = input ++ buildVars - val ev = - BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) - s""" - |$eval - |${ev.code} - |$existsVar = !${ev.isNull} && ${ev.value}; - """.stripMargin - } else { - s"$existsVar = true;" - } - - val resultVar = input ++ Seq(ExprCode.forNonNullValue( - JavaCode.variable(existsVar, BooleanType))) - if (broadcastRelation.value.keyIsUnique) { - s""" - |// generate join key for stream side - |${keyEv.code} - |// find matches from HashedRelation - |UnsafeRow $matched = $anyNull ? 
-         |boolean $existsVar = false;
-         |if ($matched != null) {
-         |  $checkCondition
-         |}
-         |$numOutput.add(1);
-         |${consume(ctx, resultVar)}
-       """.stripMargin
     } else {
-      val matches = ctx.freshName("matches")
-      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
-      s"""
-         |// generate join key for stream side
-         |${keyEv.code}
-         |// find matches from HashRelation
-         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
-         |boolean $existsVar = false;
-         |if ($matches != null) {
-         |  while (!$existsVar && $matches.hasNext()) {
-         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
-         |    $checkCondition
-         |  }
-         |}
-         |$numOutput.add(1);
-         |${consume(ctx, resultVar)}
-       """.stripMargin
+      super.codegenAnti(ctx, input)
     }
   }
 }
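Reviewer note: after this change, BroadcastHashJoinExec keeps only the broadcast-specific pieces (prepareBroadcast and the NAAJ branch of codegenAnti) and inherits the per-join-type codegen templates from HashJoin. A minimal, self-contained sketch of the resulting template-method shape, with simplified names that are not the real trait's API:

```scala
// Sketch only: the shared trait owns the codegen templates; each concrete join
// reports how its HashedRelation is obtained and whether keys are known unique.
trait HashJoinLike {
  /** Returns (relation variable name, whether keys are known unique at codegen time). */
  protected def prepareRelation(): (String, Boolean)

  /** Shared template: the lookup strategy depends only on what prepareRelation reports. */
  def codegenInner(): String = {
    val (relation, keyIsUnique) = prepareRelation()
    if (keyIsUnique) {
      s"UnsafeRow matched = (UnsafeRow) $relation.getValue(key); // at most one match"
    } else {
      s"Iterator matches = (Iterator) $relation.get(key); // loop over all matches"
    }
  }
}

// Broadcast flavor: the HashedRelation already exists on the driver, so key
// uniqueness is known when code is generated.
class BroadcastLike(keysUnique: Boolean) extends HashJoinLike {
  protected def prepareRelation(): (String, Boolean) = ("bhj_relation", keysUnique)
}

// Shuffled flavor: the relation is built per task at execution time, so the
// patch conservatively reports false.
class ShuffledLike extends HashJoinLike {
  protected def prepareRelation(): (String, Boolean) = ("shj_relation", false)
}
```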
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 4f22007b6584..1c6504b14189 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,14 +20,16 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{ExplainUtils, RowIterator}
+import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator}
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.{IntegralType, LongType}
+import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType}
 
-trait HashJoin extends BaseJoinExec {
+trait HashJoin extends BaseJoinExec with CodegenSupport {
   def buildSide: BuildSide
 
   override def simpleStringWithNodeId(): String = {
@@ -316,6 +318,409 @@ trait HashJoin extends BaseJoinExec {
       resultProj(r)
     }
   }
+
+  override def doProduce(ctx: CodegenContext): String = {
+    streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    joinType match {
+      case _: InnerLike => codegenInner(ctx, input)
+      case LeftOuter | RightOuter => codegenOuter(ctx, input)
+      case LeftSemi => codegenSemi(ctx, input)
+      case LeftAnti => codegenAnti(ctx, input)
+      case _: ExistenceJoin => codegenExistence(ctx, input)
+      case x =>
+        throw new IllegalArgumentException(
+          s"HashJoin should not take $x as the JoinType")
+    }
+  }
+
+  /**
+   * Returns the code for generating join key for stream side, and expression of whether the key
+   * has any null in it or not.
+   */
+  protected def genStreamSideJoinKey(
+      ctx: CodegenContext,
+      input: Seq[ExprCode]): (ExprCode, String) = {
+    ctx.currentVars = input
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
+      // generate the join key as Long
+      val ev = streamedBoundKeys.head.genCode(ctx)
+      (ev, ev.isNull)
+    } else {
+      // generate the join key as UnsafeRow
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+      (ev, s"${ev.value}.anyNull()")
+    }
+  }
+
+  /**
+   * Generates the code for variable of build side.
+   */
+  private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
+    ctx.currentVars = null
+    ctx.INPUT_ROW = matched
+    buildPlan.output.zipWithIndex.map { case (a, i) =>
+      val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+      if (joinType.isInstanceOf[InnerLike]) {
+        ev
+      } else {
+        // the variables are needed even there is no matched rows
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val javaType = CodeGenerator.javaType(a.dataType)
+        val code = code"""
+          |boolean $isNull = true;
+          |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+          |if ($matched != null) {
+          |  ${ev.code}
+          |  $isNull = ${ev.isNull};
+          |  $value = ${ev.value};
+          |}
+         """.stripMargin
+        ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType))
+      }
+    }
+  }
+
+  /**
+   * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
+   * and Left Anti joins.
+   */
+  protected def getJoinCondition(
+      ctx: CodegenContext,
+      input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+      val skipRow = s"${ev.isNull} || !${ev.value}"
+      s"""
+         |$eval
+         |${ev.code}
+         |if (!($skipRow))
+       """.stripMargin
+    } else {
+      ""
+    }
+    (matched, checkCondition, buildVars)
+  }
+
+  /**
+   * Generates the code for Inner join.
+   */
+  protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
+
+    if (keyIsUnique) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched != null) {
+         |  $checkCondition {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
+         |}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$iteratorCls $matches = $anyNull ?
+         |  null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+         |if ($matches != null) {
+         |  while ($matches.hasNext()) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |    $checkCondition {
+         |      $numOutput.add(1);
+         |      ${consume(ctx, resultVars)}
+         |    }
+         |  }
+         |}
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Generates the code for left or right outer join.
+   */
+  protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    // filter the output via condition
+    val conditionPassed = ctx.freshName("conditionPassed")
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      ctx.currentVars = input ++ buildVars
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+      s"""
+         |boolean $conditionPassed = true;
+         |${eval.trim}
+         |if ($matched != null) {
+         |  ${ev.code}
+         |  $conditionPassed = !${ev.isNull} && ${ev.value};
+         |}
+       """.stripMargin
+    } else {
+      s"final boolean $conditionPassed = true;"
+    }
+
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
+
+    if (keyIsUnique) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |${checkCondition.trim}
+         |if (!$conditionPassed) {
+         |  $matched = null;
+         |  // reset the variables those are already evaluated.
+         |  ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} = true;").mkString("\n")}
+         |}
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+      val found = ctx.freshName("found")
+
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+         |boolean $found = false;
+         |// the last iteration of this loop is to emit an empty row if there is no matched rows.
+         |while ($matches != null && $matches.hasNext() || !$found) {
+         |  UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+         |    (UnsafeRow) $matches.next() : null;
+         |  ${checkCondition.trim}
+         |  if ($conditionPassed) {
+         |    $found = true;
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
+         |}
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Generates the code for left semi join.
+   */
+  protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    if (keyIsUnique) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched != null) {
+         |  $checkCondition {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, input)}
+         |  }
+         |}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+      val found = ctx.freshName("found")
+
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+         |if ($matches != null) {
+         |  boolean $found = false;
+         |  while (!$found && $matches.hasNext()) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |    $checkCondition {
+         |      $found = true;
+         |    }
+         |  }
+         |  if ($found) {
+         |    $numOutput.add(1);
+         |    ${consume(ctx, input)}
+         |  }
+         |}
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Generates the code for anti join.
+   */
+  protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    if (keyIsUnique) {
+      val found = ctx.freshName("found")
+      s"""
+         |boolean $found = false;
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// Check if the key has nulls.
+         |if (!($anyNull)) {
+         |  // Check if the HashedRelation exists.
+         |  UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |  if ($matched != null) {
+         |    // Evaluate the condition.
+         |    $checkCondition {
+         |      $found = true;
+         |    }
+         |  }
+         |}
+         |if (!$found) {
+         |  $numOutput.add(1);
+         |  ${consume(ctx, input)}
+         |}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+      val found = ctx.freshName("found")
+      s"""
+         |boolean $found = false;
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// Check if the key has nulls.
+         |if (!($anyNull)) {
+         |  // Check if the HashedRelation exists.
+         |  $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
+         |  if ($matches != null) {
+         |    // Evaluate the condition.
+         |    while (!$found && $matches.hasNext()) {
+         |      UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |      $checkCondition {
+         |        $found = true;
+         |      }
+         |    }
+         |  }
+         |}
+         |if (!$found) {
+         |  $numOutput.add(1);
+         |  ${consume(ctx, input)}
+         |}
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Generates the code for existence join.
+   */
+  protected def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val existsVar = ctx.freshName("exists")
+
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+      s"""
+         |$eval
+         |${ev.code}
+         |$existsVar = !${ev.isNull} && ${ev.value};
+       """.stripMargin
+    } else {
+      s"$existsVar = true;"
+    }
+
+    val resultVar = input ++ Seq(ExprCode.forNonNullValue(
+      JavaCode.variable(existsVar, BooleanType)))
+
+    if (keyIsUnique) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |boolean $existsVar = false;
+         |if ($matched != null) {
+         |  $checkCondition
+         |}
+         |$numOutput.add(1);
+         |${consume(ctx, resultVar)}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+         |boolean $existsVar = false;
+         |if ($matches != null) {
+         |  while (!$existsVar && $matches.hasNext()) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.next();
+         |    $checkCondition
+         |  }
+         |}
+         |$numOutput.add(1);
+         |${consume(ctx, resultVar)}
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Returns a tuple of variable name for HashedRelation,
+   * and a boolean to indicate whether keys of HashedRelation
+   * known to be unique in code-gen time.
+   */
+  protected def prepareRelation(ctx: CodegenContext): (String, Boolean)
 }
 
 object HashJoin {
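Reviewer note: every concrete hash join now inherits a single set of codegen templates and differs only in prepareRelation. For orientation, a hand-expanded sketch of roughly what the codegenInner template produces when keyIsUnique is true; the identifiers are simplified stand-ins for ctx.freshName output, not actual compiler output:

```scala
// Hand-expanded sketch of the Java emitted by codegenInner for a unique-key build side.
// "shj_*" names are illustrative; real generated code uses CodegenContext fresh names.
val innerUniqueKeySketch: String =
  """
    |// generate join key for stream side
    |long shj_key = shj_streamedRow.getLong(0);
    |boolean shj_keyIsNull = shj_streamedRow.isNullAt(0);
    |// find matches from HashedRelation
    |UnsafeRow shj_matched = shj_keyIsNull ? null : (UnsafeRow) shj_relation.getValue(shj_key);
    |if (shj_matched != null) {
    |  // a non-equi `condition`, when defined, is inlined here as an extra if-guard
    |  shj_numOutputRows.add(1);
    |  // consume(...) appends the joined row to this codegen stage's output
    |}
  """.stripMargin
```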
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 3b398dd7120c..9f811cddef6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -49,7 +50,10 @@ case class ShuffledHashJoinExec(
 
   override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning
 
-  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
+  /**
+   * This is called by generated Java class, should be public.
+   */
+  def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     val buildDataSize = longMetric("buildDataSize")
     val buildTime = longMetric("buildTime")
     val start = System.nanoTime()
@@ -70,4 +74,20 @@ case class ShuffledHashJoinExec(
       join(streamIter, hashed, numOutputRows)
     }
   }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    streamedPlan.execute() :: buildPlan.execute() :: Nil
+  }
+
+  override def needCopyResult: Boolean = true
+
+  protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = {
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val clsName = classOf[HashedRelation].getName
+
+    // Inline mutable state since not many join operations in a task
+    val relationTerm = ctx.addMutableState(clsName, "relation",
+      v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true)
+    (relationTerm, false)
+  }
 }
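Reviewer note: prepareRelation here builds the HashedRelation lazily inside the generated class from inputs[1] (the build-side input registered by inputRDDs), which is why buildHashedRelation had to become public, and it always reports keyIsUnique = false because uniqueness is not known until the map is built at runtime. A quick way to exercise the path without touching global confs, assuming Spark 3.0's join hints (an illustrative snippet, not from the patch):

```scala
import org.apache.spark.sql.SparkSession

object ShjHintDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("shj-hint").getOrCreate()
    import spark.implicits._

    val left = spark.range(1000).select($"id".as("k1"))
    val right = spark.range(1000).select($"id".as("k2"))

    // SHUFFLE_HASH forces a ShuffledHashJoinExec where legal; with this patch the
    // join node should appear inside a WholeStageCodegen stage in the plan below.
    left.hint("SHUFFLE_HASH").join(right, $"k1" === $"k2").explain("formatted")

    spark.stop()
  }
}
```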
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 03596d8654c6..fe40d7dce344 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAnd
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -71,6 +70,31 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
 
+  test("ShuffledHashJoin should be included in WholeStageCodegen") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "30",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+      val df1 = spark.range(5).select($"id".as("k1"))
+      val df2 = spark.range(15).select($"id".as("k2"))
+      val df3 = spark.range(6).select($"id".as("k3"))
+
+      // test one shuffled hash join
+      val oneJoinDF = df1.join(df2, $"k1" === $"k2")
+      assert(oneJoinDF.queryExecution.executedPlan.collect {
+        case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+      }.size === 1)
+      checkAnswer(oneJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))
+
+      // test two shuffled hash joins
+      val twoJoinsDF = df1.join(df2, $"k1" === $"k2").join(df3, $"k1" === $"k3")
+      assert(twoJoinsDF.queryExecution.executedPlan.collect {
+        case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+      }.size === 2)
+      checkAnswer(twoJoinsDF,
+        Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
+    }
+  }
+
   test("Sort should be included in WholeStageCodegen") {
     val df = spark.range(3, 0, -1).toDF().sort(col("id"))
     val plan = df.queryExecution.executedPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 50652690339a..078a3ba029e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -346,8 +346,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     val rightDf = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key2", "value")
     Seq((0L, "right_outer", leftDf, rightDf, 10L, false),
       (0L, "left_outer", rightDf, leftDf, 10L, false),
-      (0L, "right_outer", leftDf, rightDf, 10L, true),
-      (0L, "left_outer", rightDf, leftDf, 10L, true),
+      (1L, "right_outer", leftDf, rightDf, 10L, true),
+      (1L, "left_outer", rightDf, leftDf, 10L, true),
       (2L, "left_anti", rightDf, leftDf, 8L, true),
       (2L, "left_semi", rightDf, leftDf, 2L, true),
       (1L, "left_anti", rightDf, leftDf, 8L, false),
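Reviewer note: the leading tuple element in these SQLMetricsSuite cases appears to be the expected plan node id of the ShuffledHashJoin operator. The outer-join rows with whole-stage codegen enabled move from 0L to 1L because the join now sits one level deeper in the plan, under the WholeStageCodegen node. Roughly (a simplified, assumed plan shape, shown only for orientation):

```scala
// Before the patch (codegen enabled):     After the patch (codegen enabled):
//   (0) ShuffledHashJoin                    (0) WholeStageCodegen (1)
//       ...                                 (1)   ShuffledHashJoin
//                                                 ...
// Hence the expected node id for the outer-join metrics checks changes 0L -> 1L.
```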