Skip to content

Commit 98eaae9

Browse files
committed
[SPARK-22520][SQL] Support code generation for large CaseWhen
1 parent 4bacddb commit 98eaae9

File tree

9 files changed

+64
-250
lines changed

9 files changed

+64
-250
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ class EquivalentExpressions {
8787
def childrenToRecurse: Seq[Expression] = expr match {
8888
case _: CodegenFallback => Nil
8989
case i: If => i.predicate :: Nil
90-
// `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
91-
case c: CaseWhenCodegen => c.children.head :: Nil
90+
case c: CaseWhen => c.children.head :: Nil
9291
case c: Coalesce => c.children.head :: Nil
9392
case other => other.children
9493
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 61 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
141141
}
142142

143143
/**
144-
* Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
144+
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
145+
* When a = true, returns b; when c = true, returns d; else returns e.
145146
*
146147
* @param branches seq of (branch condition, branch value)
147148
* @param elseValue optional value for the else branch
148149
*/
149-
abstract class CaseWhenBase(
150+
// scalastyle:off line.size.limit
151+
@ExpressionDescription(
152+
usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
153+
arguments = """
154+
Arguments:
155+
* expr1, expr3 - the branch condition expressions should all be boolean type.
156+
* expr2, expr4, expr5 - the branch value expressions and else value expression should all be
157+
same type or coercible to a common type.
158+
""",
159+
examples = """
160+
Examples:
161+
> SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
162+
1
163+
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
164+
2
165+
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
166+
NULL
167+
""")
168+
// scalastyle:on line.size.limit
169+
case class CaseWhen(
150170
branches: Seq[(Expression, Expression)],
151-
elseValue: Option[Expression])
171+
elseValue: Option[Expression] = None)
152172
extends Expression with Serializable {
153173

154174
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -211,111 +231,61 @@ abstract class CaseWhenBase(
211231
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
212232
"CASE" + cases + elseCase + " END"
213233
}
214-
}
215-
216-
217-
/**
218-
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
219-
* When a = true, returns b; when c = true, returns d; else returns e.
220-
*
221-
* @param branches seq of (branch condition, branch value)
222-
* @param elseValue optional value for the else branch
223-
*/
224-
// scalastyle:off line.size.limit
225-
@ExpressionDescription(
226-
usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
227-
arguments = """
228-
Arguments:
229-
* expr1, expr3 - the branch condition expressions should all be boolean type.
230-
* expr2, expr4, expr5 - the branch value expressions and else value expression should all be
231-
same type or coercible to a common type.
232-
""",
233-
examples = """
234-
Examples:
235-
> SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
236-
1
237-
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
238-
2
239-
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
240-
NULL
241-
""")
242-
// scalastyle:on line.size.limit
243-
case class CaseWhen(
244-
val branches: Seq[(Expression, Expression)],
245-
val elseValue: Option[Expression] = None)
246-
extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
247-
248-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
249-
super[CodegenFallback].doGenCode(ctx, ev)
250-
}
251-
252-
def toCodegen(): CaseWhenCodegen = {
253-
CaseWhenCodegen(branches, elseValue)
254-
}
255-
}
256-
257-
/**
258-
* CaseWhen expression used when code generation condition is satisfied.
259-
* OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
260-
*
261-
* @param branches seq of (branch condition, branch value)
262-
* @param elseValue optional value for the else branch
263-
*/
264-
case class CaseWhenCodegen(
265-
val branches: Seq[(Expression, Expression)],
266-
val elseValue: Option[Expression] = None)
267-
extends CaseWhenBase(branches, elseValue) with Serializable {
268234

269235
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
270-
// Generate code that looks like:
271-
//
272-
// condA = ...
273-
// if (condA) {
274-
// valueA
275-
// } else {
276-
// condB = ...
277-
// if (condB) {
278-
// valueB
279-
// } else {
280-
// condC = ...
281-
// if (condC) {
282-
// valueC
283-
// } else {
284-
// elseValue
285-
// }
286-
// }
287-
// }
236+
val conditionMet = ctx.freshName("caseWhenConditionMet")
237+
ctx.addMutableState("boolean", ev.isNull, "")
238+
ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
288239
val cases = branches.map { case (condExpr, valueExpr) =>
289240
val cond = condExpr.genCode(ctx)
290241
val res = valueExpr.genCode(ctx)
291242
s"""
292-
${cond.code}
293-
if (!${cond.isNull} && ${cond.value}) {
294-
${res.code}
295-
${ev.isNull} = ${res.isNull};
296-
${ev.value} = ${res.value};
243+
if(!$conditionMet) {
244+
${cond.code}
245+
if (!${cond.isNull} && ${cond.value}) {
246+
${res.code}
247+
${ev.isNull} = ${res.isNull};
248+
${ev.value} = ${res.value};
249+
$conditionMet = true;
250+
}
297251
}
298252
"""
299253
}
300254

301-
var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
302-
303-
elseValue.foreach { elseExpr =>
255+
val elseCode = elseValue.map { elseExpr =>
304256
val res = elseExpr.genCode(ctx)
305-
generatedCode +=
306-
s"""
257+
s"""
258+
if(!$conditionMet) {
307259
${res.code}
308260
${ev.isNull} = ${res.isNull};
309261
${ev.value} = ${res.value};
310-
"""
311-
}
262+
}
263+
"""
264+
}.getOrElse("")
312265

313-
generatedCode += "}\n" * cases.size
266+
val casesCode = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
267+
cases.mkString("\n")
268+
} else {
269+
ctx.splitExpressions(cases, "caseWhen",
270+
("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean",
271+
makeSplitFunction = {
272+
func =>
273+
s"""
274+
$func
275+
return $conditionMet;
276+
"""
277+
},
278+
foldFunctions = { funcCalls =>
279+
funcCalls.map(funcCall => s"$conditionMet = $funcCall;").mkString("\n")
280+
})
281+
}
314282

315283
ev.copy(code = s"""
316-
boolean ${ev.isNull} = true;
317-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
318-
$generatedCode""")
284+
${ev.isNull} = true;
285+
${ev.value} = ${ctx.defaultValue(dataType)};
286+
boolean $conditionMet = false;
287+
$casesCode
288+
$elseCode""")
319289
}
320290
}
321291

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
138138
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
139139
Batch("Check Cartesian Products", Once,
140140
CheckCartesianProducts) ::
141-
Batch("OptimizeCodegen", Once,
142-
OptimizeCodegen) ::
143141
Batch("RewriteSubquery", Once,
144142
RewritePredicateSubquery,
145143
CollapseProject) :: Nil

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -552,21 +552,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
552552
}
553553

554554

555-
/**
556-
* Optimizes expressions by replacing according to CodeGen configuration.
557-
*/
558-
object OptimizeCodegen extends Rule[LogicalPlan] {
559-
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
560-
case e: CaseWhen if canCodegen(e) => e.toCodegen()
561-
}
562-
563-
private def canCodegen(e: CaseWhen): Boolean = {
564-
val numBranches = e.branches.size + e.elseValue.size
565-
numBranches <= SQLConf.get.maxCaseBranchesForCodegen
566-
}
567-
}
568-
569-
570555
/**
571556
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
572557
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,12 +570,6 @@ object SQLConf {
570570
.booleanConf
571571
.createWithDefault(true)
572572

573-
val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches")
574-
.internal()
575-
.doc("The maximum number of switches supported with codegen.")
576-
.intConf
577-
.createWithDefault(20)
578-
579573
val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines")
580574
.internal()
581575
.doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.")
@@ -1084,8 +1078,6 @@ class SQLConf extends Serializable with Logging {
10841078

10851079
def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
10861080

1087-
def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
1088-
10891081
def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
10901082

10911083
def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
7777
}
7878

7979
test("SPARK-13242: case-when expression with large number of branches (or cases)") {
80-
val cases = 50
80+
val cases = 500
8181
val clauses = 20
8282

8383
// Generate an individual case

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala

Lines changed: 0 additions & 101 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class FlatMapGroupsWithState_StateManager(
9090
val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
9191
case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
9292
}
93-
CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen()
93+
CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser)
9494
}
9595

9696
// Converters for translating state between rows and Java objects

0 commit comments

Comments
 (0)