Skip to content

Commit 6225c8e

Browse files
committed
adding test case
1 parent 98eaae9 commit 6225c8e

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,14 @@ case class CaseWhen(
261261
${ev.value} = ${res.value};
262262
}
263263
"""
264-
}.getOrElse("")
264+
}
265+
266+
val allConditions = cases ++ elseCode
265267

266-
val casesCode = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
267-
cases.mkString("\n")
268+
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
269+
allConditions.mkString("\n")
268270
} else {
269-
ctx.splitExpressions(cases, "caseWhen",
271+
ctx.splitExpressions(allConditions, "caseWhen",
270272
("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean",
271273
makeSplitFunction = {
272274
func =>
@@ -284,8 +286,7 @@ case class CaseWhen(
284286
${ev.isNull} = true;
285287
${ev.value} = ${ctx.defaultValue(dataType)};
286288
boolean $conditionMet = false;
287-
$casesCode
288-
$elseCode""")
289+
$code""")
289290
}
290291
}
291292

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.scalatest.Matchers._
2929
import org.apache.spark.SparkException
3030
import org.apache.spark.sql.catalyst.TableIdentifier
3131
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
32-
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
32+
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
3333
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
3434
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
3535
import org.apache.spark.sql.functions._
@@ -2126,4 +2126,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
21262126
val mean = result.select("DecimalCol").where($"summary" === "mean")
21272127
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
21282128
}
2129+
2130+
test("SPARK-22520: support code generation for large CaseWhen") {
2131+
val N = 30
2132+
var expr1 = when($"id" === lit(0), 0)
2133+
var expr2 = when($"id" === lit(0), 10)
2134+
(1 to N).foreach { i =>
2135+
expr1 = expr1.when($"id" === lit(i), -i)
2136+
expr2 = expr2.when($"id" === lit(i + 10), i)
2137+
}
2138+
val df = spark.range(1).select(expr1, expr2.otherwise(0))
2139+
df.show
2140+
assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
2141+
}
21292142
}

0 commit comments

Comments
 (0)