Skip to content

Commit 6b6dd68

Browse files
DonnyZonecloud-fan
authored andcommitted
[SPARK-21441][SQL] Incorrect Codegen in SortMergeJoinExec results failures in some cases
## What changes were proposed in this pull request? https://issues.apache.org/jira/projects/SPARK/issues/SPARK-21441 This issue can be reproduced by the following example: ``` val spark = SparkSession .builder() .appName("smj-codegen") .master("local") .config("spark.sql.autoBroadcastJoinThreshold", "1") .getOrCreate() val df1 = spark.createDataFrame(Seq((1, 1), (2, 2), (3, 3))).toDF("key", "int") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "3"))).toDF("key", "str") val df = df1.join(df2, df1("key") === df2("key")) .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1") .select("int") df.show() ``` To conclude, the issue happens when: (1) SortMergeJoin condition contains CodegenFallback expressions. (2) In PhysicalPlan tree, SortMergeJoin node is the child of root node, e.g., the Project in above example. This patch fixes the logic in `CollapseCodegenStages` rule. ## How was this patch tested? Unit test and manual verification in our cluster. Author: donnyzone <[email protected]> Closes #18656 from DonnyZone/Fix_SortMergeJoinExec.
1 parent 4eb081c commit 6b6dd68

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
489489
* Inserts an InputAdapter on top of those that do not support codegen.
490490
*/
491491
private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
492-
case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen =>
493-
// The children of SortMergeJoin should do codegen separately.
494-
j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
495-
right = InputAdapter(insertWholeStageCodegen(right)))
496492
case p if !supportCodegen(p) =>
497493
// collapse them recursively
498494
InputAdapter(insertWholeStageCodegen(p))
495+
case j @ SortMergeJoinExec(_, _, _, _, left, right) =>
496+
// The children of SortMergeJoin should do codegen separately.
497+
j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
498+
right = InputAdapter(insertWholeStageCodegen(right)))
499499
case p =>
500500
p.withNewChildren(p.children.map(insertInputAdapter))
501501
}

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
2222
import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
2323
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
2424
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
25+
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
2526
import org.apache.spark.sql.expressions.scalalang.typed
2627
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
28+
import org.apache.spark.sql.internal.SQLConf
2729
import org.apache.spark.sql.test.SharedSQLContext
2830
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
2931

@@ -127,4 +129,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
127129
"named_struct('a',id+2, 'b',id+2) as col2")
128130
.filter("col1 = col2").count()
129131
}
132+
133+
test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") {
134+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
135+
import testImplicits._
136+
137+
val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int")
138+
val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str")
139+
140+
val df = df1.join(df2, df1("key") === df2("key"))
141+
.filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1")
142+
.select("int")
143+
144+
val plan = df.queryExecution.executedPlan
145+
assert(!plan.find(p =>
146+
p.isInstanceOf[WholeStageCodegenExec] &&
147+
p.asInstanceOf[WholeStageCodegenExec].child.children(0)
148+
.isInstanceOf[SortMergeJoinExec]).isDefined)
149+
assert(df.collect() === Array(Row(1), Row(2)))
150+
}
151+
}
130152
}

0 commit comments

Comments
 (0)