diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 4e02803552e8..8d272ac313e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -585,21 +585,26 @@ case class SortMergeJoinExec( val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") + val joinedRow = ctx.freshName("joined") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. val loaded = ctx.freshName("loaded") val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) // Generate code for condition + // set INPUT_ROW to the joined row because it is the data for the condition + ctx.INPUT_ROW = joinedRow ctx.currentVars = leftVars ++ rightVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" |boolean $loaded = false; + |$joinedRow.withLeft($leftRow); |$leftBefore """.stripMargin val checking = s""" + |$joinedRow.withRight($rightRow); |$rightBefore |${cond.code} |if (${cond.isNull} || !${cond.value}) continue; @@ -615,6 +620,7 @@ case class SortMergeJoinExec( } s""" + |JoinedRow $joinedRow = new JoinedRow(); |while (findNextInnerJoinRows($leftInput, $rightInput)) { | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4408ece11225..e69271fa07af 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{And, BinaryExpression, Expression, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join @@ -124,7 +125,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) - EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) + EnsureRequirements(spark.sessionState.conf) + .apply(ProjectExec(sortMergeJoin.output, sortMergeJoin)) } test(s"$testName using BroadcastHashJoin (build=left)") { @@ -228,6 +230,27 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) ) + testInnerJoin( + "inner join with CodegenFallback filter", + myUpperCaseData, + myLowerCaseData, + () => { + // add a second equality check that is implemented with a CodegenFallback + // this expression is in the test so that no one implements codegen for it + And( + (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + EqNoCodegen( + org.apache.spark.sql.functions.lower(myUpperCaseData.col("L")).expr, + myLowerCaseData.col("l").expr)) + }, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + { lazy val left = myTestData1.where("a = 1") lazy val right = myTestData2.where("a = 1") @@ -287,3 +310,10 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { (Row(2, 2), "L2", Row(2, 2), "R2"))) } 
}
+// Test-only equality predicate; mixes in CodegenFallback so it has no codegen of its own.
+case class EqNoCodegen(left: Expression, right: Expression) extends BinaryExpression
+  with CodegenFallback with Serializable with Predicate {
+  override protected def nullSafeEval(left: Any, right: Any): Boolean = {
+    left == right // Scala == on the two arguments (delegates to equals)
+  }
+}