From 52efe91168a4be7ce721d2f56e2b1e7aab9379db Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Feb 2016 16:33:44 -0800 Subject: [PATCH 1/8] generated broadcast outer join --- .../spark/sql/execution/SparkStrategies.scala | 16 +- .../sql/execution/WholeStageCodegen.scala | 13 +- .../execution/joins/BroadcastHashJoin.scala | 146 ++++++++++++++++- .../joins/BroadcastHashOuterJoin.scala | 145 ----------------- .../spark/sql/execution/joins/HashJoin.scala | 95 ++++++++++- .../sql/execution/joins/HashOuterJoin.scala | 153 ------------------ .../org/apache/spark/sql/JoinSuite.scala | 5 +- .../BenchmarkWholeStageCodegen.scala | 11 ++ .../execution/joins/BroadcastJoinSuite.scala | 2 +- .../sql/execution/joins/InnerJoinSuite.scala | 2 +- .../sql/execution/joins/OuterJoinSuite.scala | 7 +- 11 files changed, 265 insertions(+), 330 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ee392e4e8deb..4f80bd8e5ecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -83,12 +83,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Inner joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastHashJoin( - leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - joins.BroadcastHashJoin( - leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => @@ -99,13 +99,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys( LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys( RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) => - joins.BroadcastHashOuterJoin( - leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil + Seq(joins.BroadcastHashJoin( + leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 4ca2d85406bb..d6c915138401 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} -import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -38,7 +37,7 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { case _: TungstenAggregate => "agg" - case _: BroadcastHashJoin => "bhj" + case _: BroadcastHashJoin => "join" case _ => nodeName.toLowerCase } @@ -366,17 +365,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru def apply(plan: SparkPlan): SparkPlan = { if (sqlContext.conf.wholeStageEnabled) { plan.transform { - case plan: CodegenSupport if supportCodegen(plan) && - // Whole stage codegen is only useful when there are at least two levels of operators that - // support it (save at least one projection/iterator). - (Utils.isTesting || plan.children.exists(supportCodegen)) => - + case plan: CodegenSupport if supportCodegen(plan) => var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { // The build side can't be compiled together - case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) => b.copy(left = apply(left)) - case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => b.copy(right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index cbd549763ac9..df7173e79b4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -24,10 +24,11 @@ import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SQLExecution, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.util.collection.CompactBuffer @@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, @@ 
-117,9 +119,36 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - val hashedRelation = broadcastRelation.value - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) - hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) + val keyGenerator = streamSideKeyGenerator + val resultProj = createResultProjection + + joinType match { + case Inner => + hashJoin(streamedIter, numStreamedRows, hashTable, numOutputRows) + + case LeftOuter => + streamedIter.flatMap(currentRow => { + numStreamedRows += 1 + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) + }) + + case RightOuter => + streamedIter.flatMap(currentRow => { + numStreamedRows += 1 + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) + }) + + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } } } @@ -149,6 +178,14 @@ case class BroadcastHashJoin( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (joinType == Inner) { + doConsumeInnerJoin(ctx, input) + } else { + doConsumeOuterJoin(ctx, input) + } + } + + private def doConsumeInnerJoin(ctx: CodegenContext, input: Seq[ExprCode]): String = { // generate the key as UnsafeRow or Long ctx.currentVars = input val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { @@ -223,6 +260,105 @@ case class BroadcastHashJoin( """.stripMargin } } + + private def doConsumeOuterJoin(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // generate the key as UnsafeRow or Long + ctx.currentVars = input + val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + val expr = rewriteKeyExpr(streamedKeys).head + val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + (ev, ev.isNull) + } else { + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + (ev, s"${ev.value}.anyNull()") + } + + // find the matches from HashedRelation + val matched = ctx.freshName("matched") + val valid = ctx.freshName("invalid") + + // create variables for output + ctx.currentVars = null + ctx.INPUT_ROW = matched + val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) + } + + // output variables + val resultVars = buildSide match { + case BuildLeft => buildColumns ++ input + case BuildRight => input ++ buildColumns + } + + // filter the output via condition + val checkCondition = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + |boolean $valid = true; + |if ($matched != null) { + | ${ev.code} + | $valid = 
!${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $valid = true;" + } + + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + |// generate join key + |${keyVal.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); + |${buildColumns.map(_.code).mkString("\n")} + |${checkCondition.trim} + |if (!$valid) { + | // reset to null + | ${buildColumns.map(v => s"${v.isNull} = true;").mkString("\n")} + |} + |${consume(ctx, resultVars)} + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val found = ctx.freshName("found") + s""" + |// generate join key + |${keyVal.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : + | ($bufferType) $relationTerm.get(${keyVal.value}); + |int $size = $matches != null ? $matches.size() : 0; + |boolean $found = false; + |for (int $i = 0; $i <= $size; $i++) { + | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; + | ${buildColumns.map(_.code).mkString("\n")} + | ${checkCondition.trim} + | if ($valid && ($i < $size || !$found)) { + | $found = true; + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin + } + } } object BroadcastHashJoin { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala deleted file mode 100644 index ad3275696e63..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Performs a outer hash join for two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. 
The streamed - * relation is not shuffled. - */ -case class BroadcastHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - val timeout = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - val numBuildRows = joinType match { - case RightOuter => longMetric("numLeftRows") - case LeftOuter => longMetric("numRightRows") - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - numBuildRows += 1 - row.copy() - }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. 
- val hashed = HashedRelation( - input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size) - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture - } - - override def doExecute(): RDD[InternalRow] = { - val numStreamedRows = joinType match { - case RightOuter => longMetric("numRightRows") - case LeftOuter => longMetric("numLeftRows") - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - val numOutputRows = longMetric("numOutputRows") - - val broadcastRelation = Await.result(broadcastFuture, timeout) - - streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - val keyGenerator = streamedKeyGenerator - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) - - val resultProj = resultProjection - joinType match { - case LeftOuter => - streamedIter.flatMap(currentRow => { - numStreamedRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - streamedIter.flatMap(currentRow => { - numStreamedRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ecbb1ac64b7c..3fc63d6c7ad8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -21,20 +21,38 @@ import java.util.NoSuchElementException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.{IntegralType, LongType} +import org.apache.spark.util.collection.CompactBuffer trait HashJoin { self: SparkPlan => val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] + val joinType: JoinType val buildSide: BuildSide val condition: Option[Expression] val left: SparkPlan val right: SparkPlan + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") + } + } + protected lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) @@ -45,8 +63,6 @@ trait HashJoin { case BuildRight => (rightKeys, leftKeys) } - override def output: Seq[Attribute] = left.output ++ right.output - /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. 
* @@ -97,6 +113,9 @@ trait HashJoin { (r: InternalRow) => true } + protected def createResultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(self.schema) + protected def hashJoin( streamIter: Iterator[InternalRow], numStreamRows: LongSQLMetric, @@ -110,8 +129,7 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(self.schema) + private[this] val resultProjection = createResultProjection private[this] val joinKeys = streamSideKeyGenerator @@ -165,4 +183,73 @@ trait HashJoin { } } } + + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() + + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) + + protected[this] def leftOuterIterator( + key: InternalRow, + joinedRow: JoinedRow, + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { + if (!key.anyNull) { + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } + } + } else { + List.empty + } + if (temp.isEmpty) { + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } else { + temp + } + } else { + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } + } + ret.iterator + } + + protected[this] def rightOuterIterator( + key: InternalRow, + leftIter: Iterable[InternalRow], + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { + if (!key.anyNull) { + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } + } + } else { + List.empty + } + if (temp.isEmpty) { + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } else { + temp + } + } else { + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } + } + ret.iterator + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala deleted file mode 100644 index 9e614309de12..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.util.collection.CompactBuffer - - -trait HashOuterJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val joinType: JoinType - val condition: Option[Expression] - val left: SparkPlan - val right: SparkPlan - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - - protected[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected def buildKeyGenerator: Projection = - UnsafeProjection.create(buildKeys, buildPlan.output) - - protected[this] def streamedKeyGenerator: Projection = - UnsafeProjection.create(streamedKeys, streamedPlan.output) - - protected[this] def resultProjection: InternalRow => InternalRow = - UnsafeProjection.create(output, output) - - @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) - @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - } else { - (row: InternalRow) => true - } - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. 
- - protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } - } - ret.iterator - } - - protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } - } - ret.iterator - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 9a3c262e9485..92ff7e73fad8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), + classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index f015d297048a..9578fca17de4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -155,6 +155,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X */ + runBenchmark("outer join w long", N) { + sqlContext.range(N).join(dim, (col("id") % 60000) === col("k"), "left").count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + outer 
join w long codegen=false 19438 / 19879 5.4 185.4 1.0X + outer join w long codegen=true 1098 / 1129 95.5 10.5 17.7X + */ } ignore("hash and BytesToBytesMap") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index aee8e84db56e..e25b5e0610ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { } test("unsafe broadcast hash outer join updates peak execution memory") { - testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer") } test("unsafe broadcast left semi join updates peak execution memory") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 149f34dbd748..e22a810a6b42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan) + joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan) } def makeSortMergeJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 3d3e9a7b9092..e102ffe0fb01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -76,10 +76,15 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { if (joinType != FullOuter) { test(s"$testName using BroadcastHashOuterJoin") { + val buildSide = joinType match { + case LeftOuter => BuildRight + case RightOuter => BuildLeft + } extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + BroadcastHashJoin( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } From 9525782c971f52c1343830402276086cd0e4ae8f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Feb 2016 10:18:45 -0800 Subject: [PATCH 2/8] refactor --- .../execution/joins/BroadcastHashJoin.scala | 209 ++++++++---------- 1 file changed, 90 insertions(+), 119 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index df7173e79b4f..92858ec9bc0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -178,90 
+178,6 @@ case class BroadcastHashJoin( } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - if (joinType == Inner) { - doConsumeInnerJoin(ctx, input) - } else { - doConsumeOuterJoin(ctx, input) - } - } - - private def doConsumeInnerJoin(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // generate the key as UnsafeRow or Long - ctx.currentVars = input - val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { - val expr = rewriteKeyExpr(streamedKeys).head - val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) - (ev, ev.isNull) - } else { - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) - (ev, s"${ev.value}.anyNull()") - } - - // find the matches from HashedRelation - val matched = ctx.freshName("matched") - - // create variables for output - ctx.currentVars = null - ctx.INPUT_ROW = matched - val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => - BoundReference(i, a.dataType, a.nullable).gen(ctx) - } - val resultVars = buildSide match { - case BuildLeft => buildColumns ++ input - case BuildRight => input ++ buildColumns - } - - val outputCode = if (condition.isDefined) { - // filter the output via condition - ctx.currentVars = resultVars - val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) - s""" - | ${ev.code} - | if (!${ev.isNull} && ${ev.value}) { - | ${consume(ctx, resultVars)} - | } - """.stripMargin - } else { - consume(ctx, resultVars) - } - - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { - s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashedRelation - | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); - | if ($matched != null) { - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - """.stripMargin - - } else { - val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") - s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashRelation - | $bufferType $matches = ${anyNull} ? 
null : - | ($bufferType) $relationTerm.get(${keyVal.value}); - | if ($matches != null) { - | int $size = $matches.size(); - | for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - | } - """.stripMargin - } - } - - private def doConsumeOuterJoin(ctx: CodegenContext, input: Seq[ExprCode]): String = { // generate the key as UnsafeRow or Long ctx.currentVars = input val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { @@ -283,18 +199,22 @@ case class BroadcastHashJoin( ctx.INPUT_ROW = matched val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") - val code = s""" - |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; - |if ($matched != null) { - | ${ev.code} - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; - |} - """.stripMargin - ExprCode(code, isNull, value) + if (joinType == Inner) { + ev + } else { + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) + } } // output variables @@ -303,23 +223,74 @@ case class BroadcastHashJoin( case BuildRight => input ++ buildColumns } - // filter the output via condition - val checkCondition = if (condition.isDefined) { - ctx.currentVars = resultVars - val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) - s""" + if (joinType == Inner) { + val outputCode = if (condition.isDefined) { + // filter the output via condition + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + |${ev.code} + |if (!${ev.isNull} && ${ev.value}) { + | ${consume(ctx, resultVars)} + |} + """.stripMargin + } else { + consume(ctx, resultVars) + } + + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + |// generate join key + |${keyVal.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); + |if ($matched != null) { + | ${buildColumns.map(_.code).mkString("\n")} + | $outputCode + |} + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" + |// generate join key + |${keyVal.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? 
null : ($bufferType) $relationTerm.get(${keyVal.value}); + |if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + | ${buildColumns.map(_.code).mkString("\n")} + | $outputCode + | } + |} + """.stripMargin + } + + } else { + // LeftOuter and RightOuter + + // filter the output via condition + val checkCondition = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" |boolean $valid = true; |if ($matched != null) { | ${ev.code} | $valid = !${ev.isNull} && ${ev.value}; |} - """.stripMargin - } else { - s"final boolean $valid = true;" - } + """.stripMargin + } else { + s"final boolean $valid = true;" + } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { - s""" + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" |// generate join key |${keyVal.code} |// find matches from HashedRelation @@ -331,20 +302,19 @@ case class BroadcastHashJoin( | ${buildColumns.map(v => s"${v.isNull} = true;").mkString("\n")} |} |${consume(ctx, resultVars)} - """.stripMargin - - } else { - val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") - val found = ctx.freshName("found") - s""" + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val found = ctx.freshName("found") + s""" |// generate join key |${keyVal.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : - | ($bufferType) $relationTerm.get(${keyVal.value}); + |$bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get(${keyVal.value}); |int $size = $matches != null ? 
$matches.size() : 0; |boolean $found = false; |for (int $i = 0; $i <= $size; $i++) { @@ -356,7 +326,8 @@ case class BroadcastHashJoin( | ${consume(ctx, resultVars)} | } |} - """.stripMargin + """.stripMargin + } } } } From 9a1f5325e954d8464d28ebf415c9dca665e15d35 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Feb 2016 10:21:31 -0800 Subject: [PATCH 3/8] fix style --- .../apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 92858ec9bc0c..615405937b53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils import org.apache.spark.util.collection.CompactBuffer From edbc284921281358a38b300218ff288c33cdc3b4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Feb 2016 11:48:58 -0800 Subject: [PATCH 4/8] fix tests --- .../apache/spark/sql/execution/joins/OuterJoinSuite.scala | 2 +- .../spark/sql/execution/metric/SQLMetricsSuite.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index e102ffe0fb01..f4b01fbad058 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -75,7 +75,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + test(s"$testName using BroadcastHashJoin") { val buildSide = joinType match { case LeftOuter => BuildRight case RightOuter => BuildLeft diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2260e4870299..2c588be277a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -232,14 +232,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("BroadcastHashOuterJoin metrics") { + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") // Assume the execution plan is - // ... -> BroadcastHashOuterJoin(nodeId = 0) + // ... 
-> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( + 0L -> ("BroadcastHashJoin", Map( "number of left rows" -> 3L, "number of right rows" -> 4L, "number of output rows" -> 5L))) @@ -247,7 +247,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") testSparkPlanMetrics(df3, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( + 0L -> ("BroadcastHashJoin", Map( "number of left rows" -> 3L, "number of right rows" -> 4L, "number of output rows" -> 6L))) From da45df1536f112a14bfe15d6d30d307cdbd99d5b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Feb 2016 11:20:33 -0800 Subject: [PATCH 5/8] address comments --- .../execution/joins/BroadcastHashJoin.scala | 136 ++++++++++-------- .../spark/sql/execution/joins/HashJoin.scala | 23 ++- 2 files changed, 88 insertions(+), 71 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 615405937b53..17db4443bb3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -180,7 +180,7 @@ case class BroadcastHashJoin( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { // generate the key as UnsafeRow or Long ctx.currentVars = input - val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + val (keyEv, anyNull) = if (canJoinKeyFitWithinLong) { val expr = rewriteKeyExpr(streamedKeys).head val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) (ev, ev.isNull) @@ -192,12 +192,11 @@ case class BroadcastHashJoin( // find the matches from HashedRelation val matched = ctx.freshName("matched") - val valid = ctx.freshName("invalid") // create variables for output ctx.currentVars = null ctx.INPUT_ROW = matched - val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => + val buildVars = buildPlan.output.zipWithIndex.map { case (a, i) => val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) if (joinType == Inner) { ev @@ -219,98 +218,118 @@ case class BroadcastHashJoin( // output variables val resultVars = buildSide match { - case BuildLeft => buildColumns ++ input - case BuildRight => input ++ buildColumns + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars } if (joinType == Inner) { - val outputCode = if (condition.isDefined) { - // filter the output via condition - ctx.currentVars = resultVars - val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) - s""" + codegenInner(ctx, keyEv, anyNull, matched, buildVars, resultVars) + } else { + // LeftOuter and RightOuter + codegenOuter(ctx, keyEv, anyNull, matched, buildVars, resultVars) + } + } + + private def codegenInner( + ctx: CodegenContext, + keyEv: ExprCode, + anyNull: String, + matched: String, + buildVars: Seq[ExprCode], + resultVars: Seq[ExprCode]): String = { + val outputCode = if (condition.isDefined) { + // filter the output via condition + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" |${ev.code} |if (!${ev.isNull} && ${ev.value}) { | ${consume(ctx, resultVars)} |} """.stripMargin - } else { - consume(ctx, resultVars) - } + } else { + consume(ctx, 
resultVars) + } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { - s""" + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" |// generate join key - |${keyVal.code} + |${keyEv.code} |// find matches from HashedRelation - |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); |if ($matched != null) { - | ${buildColumns.map(_.code).mkString("\n")} + | ${buildVars.map(_.code).mkString("\n")} | $outputCode |} """.stripMargin - } else { - val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") - s""" + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" |// generate join key - |${keyVal.code} + |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get(${keyVal.value}); + |$bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get(${keyEv.value}); |if ($matches != null) { | int $size = $matches.size(); | for (int $i = 0; $i < $size; $i++) { | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} + | ${buildVars.map(_.code).mkString("\n")} | $outputCode | } |} """.stripMargin - } - - } else { - // LeftOuter and RightOuter + } + } - // filter the output via condition - val checkCondition = if (condition.isDefined) { - ctx.currentVars = resultVars - val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) - s""" - |boolean $valid = true; + private def codegenOuter( + ctx: CodegenContext, + keyVal: ExprCode, + anyNull: String, + matched: String, + buildVars: Seq[ExprCode], + resultVars: Seq[ExprCode]): String = { + // filter the output via condition + val passedFilter = ctx.freshName("passedFilter") + val checkCondition = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + |boolean $passedFilter = true; |if ($matched != null) { | ${ev.code} - | $valid = !${ev.isNull} && ${ev.value}; + | $passedFilter = !${ev.isNull} && ${ev.value}; |} - """.stripMargin - } else { - s"final boolean $valid = true;" - } + """.stripMargin + } else { + s"final boolean $passedFilter = true;" + } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { - s""" + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" |// generate join key |${keyVal.code} |// find matches from HashedRelation |UnsafeRow $matched = $anyNull ? 
null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); - |${buildColumns.map(_.code).mkString("\n")} + |${buildVars.map(_.code).mkString("\n")} |${checkCondition.trim} - |if (!$valid) { + |if (!$passedFilter) { | // reset to null - | ${buildColumns.map(v => s"${v.isNull} = true;").mkString("\n")} + | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")} |} |${consume(ctx, resultVars)} - """.stripMargin + """.stripMargin - } else { - val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") - val found = ctx.freshName("found") - s""" + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val found = ctx.freshName("found") + s""" |// generate join key |${keyVal.code} |// find matches from HashRelation @@ -319,15 +338,14 @@ case class BroadcastHashJoin( |boolean $found = false; |for (int $i = 0; $i <= $size; $i++) { | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; - | ${buildColumns.map(_.code).mkString("\n")} + | ${buildVars.map(_.code).mkString("\n")} | ${checkCondition.trim} - | if ($valid && ($i < $size || !$found)) { + | if ($passedFilter && ($i < $size || !$found)) { | $found = true; | ${consume(ctx, resultVars)} | } |} - """.stripMargin - } + """.stripMargin } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 3fc63d6c7ad8..9366bb2e002e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -120,8 +120,7 @@ trait HashJoin { streamIter: Iterator[InternalRow], numStreamRows: LongSQLMetric, hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = - { + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ private[this] var currentHashMatches: Seq[InternalRow] = _ @@ -190,11 +189,11 @@ trait HashJoin { @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + key: InternalRow, + joinedRow: JoinedRow, + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (rightIter != null) { @@ -222,11 +221,11 @@ trait HashJoin { } protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + key: InternalRow, + leftIter: Iterable[InternalRow], + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { val temp = if (leftIter != null) { From 1c0ee96e80d5cc1909d7d5ec794b74e76979ae45 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 12 Feb 2016 12:40:30 -0800 Subject: [PATCH 6/8] fix worst 
case of broadcast join with two ints --- .../spark/sql/execution/joins/HashJoin.scala | 13 +- .../BenchmarkWholeStageCodegen.scala | 128 +++++++++++++++--- 2 files changed, 123 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 84b6820c3534..2fe9c06cc953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.sql.types.{IntegralType, LongType} +import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} import org.apache.spark.util.collection.CompactBuffer trait HashJoin { @@ -83,8 +83,17 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 + // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same + // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys + // with two same ints have hash code 0, we rotate the bits of second one. + val rotated = if (e.dataType == IntegerType) { + // (e >>> 15) | (e << 17) + BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) + } else { + e + } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index d83e8fd1da5d..77c07981a548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.execution +import java.util.HashMap + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{StringType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap @@ -124,52 +126,69 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { ignore("broadcast hash join") { val N = 100 << 20 - val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + val M = 1 << 16 + val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long", N) { - sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() + sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------- - Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X - Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X + Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X + Join w long codegen=true 735 / 853 142.7 7.0 7.8X */ - val dim2 = broadcast(sqlContext.range(1 << 16) + val dim2 = broadcast(sqlContext.range(M) .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) runBenchmark("Join w 2 ints", N) { sqlContext.range(N).join(dim2, - (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1") - && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count() + (col("id") bitwiseAND M).cast(IntegerType) === col("k1") + && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count() } /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X - Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X + Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X + Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X */ + val dim3 = broadcast(sqlContext.range(M) + .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 longs", N) { + sqlContext.range(N).join(dim3, + (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + .count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 longs codegen=false 7877 / 8358 13.3 75.1 1.0X + Join w 2 longs codegen=true 3877 / 3937 27.0 37.0 2.0X + */ runBenchmark("outer join w long", N) { - sqlContext.range(N).join(dim, (col("id") % 60000) === col("k"), "left").count() + sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count() } /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 19438 / 19879 5.4 185.4 1.0X - outer join w long codegen=true 1098 / 1129 95.5 10.5 17.7X + outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X + outer join w long codegen=true 769 / 796 136.3 7.3 19.9X */ } ignore("hash and BytesToBytesMap") { - val N = 50 << 20 + val N = 10 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) @@ -221,6 +240,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("Java HashMap (Long)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + map.put(i.toLong, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.get(i % 100000) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (two ints) ") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new 
HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + val key = (i.toLong << 32) + Integer.rotateRight(i, 15) + map.put(key, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (UnsafeRow)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[UnsafeRow, UnsafeRow]() + while (i < 65536) { + key.setInt(0, i) + value.setInt(0, i) + map.put(key, value.copy()) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + key.setInt(0, i % 100000) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -262,6 +355,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { hash 651 / 678 80.0 12.5 1.0X fast hash 336 / 343 155.9 6.4 1.9X arrayEqual 417 / 428 125.0 8.0 1.6X + Java HashMap (Long) 145 / 168 72.2 13.8 0.8X + Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X + Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X */ From 57241806ae31130429cb68a58a4086f15c3965f4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 15 Feb 2016 17:21:57 -0800 Subject: [PATCH 7/8] add more comments --- .../execution/joins/BroadcastHashJoin.scala | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 9a295606c4b7..a64da225800a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -118,18 +118,18 @@ case class BroadcastHashJoin( hashJoin(streamedIter, hashTable, numOutputRows) case LeftOuter => - streamedIter.flatMap(currentRow => { + streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - }) + } case RightOuter => - streamedIter.flatMap(currentRow => { + streamedIter.flatMap { currentRow => val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - }) + } case x => throw new IllegalArgumentException( @@ -155,6 +155,9 @@ case class BroadcastHashJoin( } } + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. 
+ */ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation val broadcastRelation = Await.result(broadcastFuture, timeout) @@ -169,7 +172,13 @@ case class BroadcastHashJoin( (broadcastRelation, relationTerm) } - private def genJoinKey(ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input if (canJoinKeyFitWithinLong) { // generate the join key as Long @@ -184,6 +193,9 @@ case class BroadcastHashJoin( } } + /** + * Generates the code for variable of build side. + */ private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = matched @@ -209,9 +221,12 @@ case class BroadcastHashJoin( } } + /** + * Generates the code for Inner join. + */ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genJoinKey(ctx, input) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) val resultVars = buildSide match { @@ -240,7 +255,7 @@ case class BroadcastHashJoin( if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { s""" - |// generate join key + |// generate join key for stream side |${keyEv.code} |// find matches from HashedRelation |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); @@ -256,7 +271,7 @@ case class BroadcastHashJoin( val i = ctx.freshName("i") val size = ctx.freshName("size") s""" - |// generate join key + |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); @@ -273,9 +288,12 @@ case class BroadcastHashJoin( } + /** + * Generates the code for left or right outer join. + */ private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genJoinKey(ctx, input) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) val resultVars = buildSide match { @@ -302,7 +320,7 @@ case class BroadcastHashJoin( if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { s""" - |// generate join key + |// generate join key for stream side |${keyEv.code} |// find matches from HashedRelation |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); @@ -323,12 +341,13 @@ case class BroadcastHashJoin( val size = ctx.freshName("size") val found = ctx.freshName("found") s""" - |// generate join key + |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); |int $size = $matches != null ? $matches.size() : 0; |boolean $found = false; + |// the last iteration of this loop is to emit an empty row if there is no matched rows. |for (int $i = 0; $i <= $size; $i++) { | UnsafeRow $matched = $i < $size ? 
(UnsafeRow) $matches.apply($i) : null; | ${buildVars.map(_.code).mkString("\n")} From 5744941063ba05b07e4a7265277162c331a9c48c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 15 Feb 2016 17:41:01 -0800 Subject: [PATCH 8/8] fix style --- .../apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 77c07981a548..b3bfea8a8aa6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -24,7 +24,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, IntegerType} +import org.apache.spark.sql.types.IntegerType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap
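
Note on the key packing change in HashJoin.scala (PATCH 6/8): the effect of the int rotation is easiest to see in isolation. Below is a minimal standalone Scala sketch, not part of the patch; the object and method names are illustrative only. It shows why packing two equal ints into one long key makes every such key hash to 0 under Long's hashCode, and how rotating the second int with the same (e >>> 15) | (e << 17) trick avoids that worst case.

    // Standalone sketch: a Long's hashCode folds the high and low 32 bits together,
    // so a key built from two identical ints hashes to 0 unless one half is perturbed.
    object TwoIntKeyPackingSketch {
      // Naive packing, as codegen would do without the fix: high 32 bits = a, low 32 bits = b.
      def packNaive(a: Int, b: Int): Long =
        (a.toLong << 32) | (b.toLong & 0xFFFFFFFFL)

      // Packing with the patch's rotation of the second int: (e >>> 15) | (e << 17).
      def packRotated(a: Int, b: Int): Long =
        (a.toLong << 32) | (Integer.rotateRight(b, 15).toLong & 0xFFFFFFFFL)

      def main(args: Array[String]): Unit = {
        val keys = 1 until 100000
        // Long.hashCode is (l ^ (l >>> 32)).toInt, so equal halves cancel out to 0.
        val naiveZeros = keys.count(i => packNaive(i, i).hashCode == 0)
        val rotatedZeros = keys.count(i => packRotated(i, i).hashCode == 0)
        println(s"hash == 0: naive=$naiveZeros of ${keys.size}, rotated=$rotatedZeros")
        // Expected: every naive key hashes to 0, none of the rotated keys do.
      }
    }

With every naive key landing in the same hash bucket, a probe over such keys degrades to a linear scan of the build side, which is the slowdown the "Join w 2 ints" benchmark above exercises.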
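Note on the outer-join codegen in BroadcastHashJoin.scala (PATCH 7/8): the comment "the last iteration of this loop is to emit an empty row if there is no matched rows" describes a loop shape that is easier to follow outside of generated Java. The following rough Scala sketch mirrors that control flow; the names (leftOuterMatches, passesFilter) are made up for illustration and do not appear in the patch.

    import scala.collection.mutable.ArrayBuffer

    // Sketch of the generated outer-join loop: iterate one step past the matches so
    // that a stream row with no surviving match still produces exactly one output
    // row, with the build side represented as None (nulls in the generated code).
    def leftOuterMatches[L, R](
        streamRow: L,
        matches: IndexedSeq[R],
        passesFilter: (L, R) => Boolean): Seq[(L, Option[R])] = {
      val out = ArrayBuffer[(L, Option[R])]()
      var found = false
      var i = 0
      while (i <= matches.size) {  // note <=, one extra iteration past the matches
        val matched = if (i < matches.size) Some(matches(i)) else None
        val passed = matched.forall(r => passesFilter(streamRow, r))
        // Real matches are always emitted; the trailing "null match" is emitted
        // only when nothing matched, mirroring ($passedFilter && ($i < $size || !$found)).
        if (passed && (i < matches.size || !found)) {
          found = true
          out += ((streamRow, matched))
        }
        i += 1
      }
      out.toList
    }

For example, leftOuterMatches(1, IndexedSeq.empty[Int], (_: Int, _: Int) => true) returns Seq((1, None)), while a stream row with two surviving matches yields two pairs and no None entry, which is the left outer join contract the generated loop implements.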