-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-22520][SQL] Support code generation for large CaseWhen #19752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
98eaae9
6225c8e
f9c20be
9063583
c7f0a92
f4c7896
5adb513
6b280fd
c7347b1
dd5f455
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -141,14 +141,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi | |
| } | ||
|
|
||
| /** | ||
| * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen. | ||
| * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". | ||
| * When a = true, returns b; when c = true, returns d; else returns e. | ||
| * | ||
| * @param branches seq of (branch condition, branch value) | ||
| * @param elseValue optional value for the else branch | ||
| */ | ||
| abstract class CaseWhenBase( | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.", | ||
| arguments = """ | ||
| Arguments: | ||
| * expr1, expr3 - the branch condition expressions should all be boolean type. | ||
| * expr2, expr4, expr5 - the branch value expressions and else value expression should all be | ||
| same type or coercible to a common type. | ||
| """, | ||
| examples = """ | ||
| Examples: | ||
| > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; | ||
| 1 | ||
| > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; | ||
| 2 | ||
| > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END; | ||
| NULL | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class CaseWhen( | ||
| branches: Seq[(Expression, Expression)], | ||
| elseValue: Option[Expression]) | ||
| elseValue: Option[Expression] = None) | ||
| extends Expression with Serializable { | ||
|
|
||
| override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue | ||
|
|
@@ -211,111 +231,73 @@ abstract class CaseWhenBase( | |
| val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") | ||
| "CASE" + cases + elseCase + " END" | ||
| } | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". | ||
| * When a = true, returns b; when c = true, returns d; else returns e. | ||
| * | ||
| * @param branches seq of (branch condition, branch value) | ||
| * @param elseValue optional value for the else branch | ||
| */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.", | ||
| arguments = """ | ||
| Arguments: | ||
| * expr1, expr3 - the branch condition expressions should all be boolean type. | ||
| * expr2, expr4, expr5 - the branch value expressions and else value expression should all be | ||
| same type or coercible to a common type. | ||
| """, | ||
| examples = """ | ||
| Examples: | ||
| > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; | ||
| 1 | ||
| > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END; | ||
| 2 | ||
| > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END; | ||
| NULL | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class CaseWhen( | ||
| val branches: Seq[(Expression, Expression)], | ||
| val elseValue: Option[Expression] = None) | ||
| extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable { | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| super[CodegenFallback].doGenCode(ctx, ev) | ||
| } | ||
|
|
||
| def toCodegen(): CaseWhenCodegen = { | ||
| CaseWhenCodegen(branches, elseValue) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * CaseWhen expression used when code generation condition is satisfied. | ||
| * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen. | ||
| * | ||
| * @param branches seq of (branch condition, branch value) | ||
| * @param elseValue optional value for the else branch | ||
| */ | ||
| case class CaseWhenCodegen( | ||
| val branches: Seq[(Expression, Expression)], | ||
| val elseValue: Option[Expression] = None) | ||
| extends CaseWhenBase(branches, elseValue) with Serializable { | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| // Generate code that looks like: | ||
| // | ||
| // condA = ... | ||
| // if (condA) { | ||
| // valueA | ||
| // } else { | ||
| // condB = ... | ||
| // if (condB) { | ||
| // valueB | ||
| // } else { | ||
| // condC = ... | ||
| // if (condC) { | ||
| // valueC | ||
| // } else { | ||
| // elseValue | ||
| // } | ||
| // } | ||
| // } | ||
| // This variable tracks whether a branch condition has already evaluated to `true`. | ||
| // It is initialized to `false` and set to `true` by the first branch whose condition | ||
| // evaluates to `true`; from then on, the remaining conditions and the else value are | ||
| // skipped, because there is no need to evaluate them anymore. | ||
| val conditionMet = ctx.freshName("caseWhenConditionMet") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment to explain what it is. |
||
| ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull, "") | ||
|
||
| ctx.addMutableState(ctx.javaType(dataType), ev.value, "") | ||
| val cases = branches.map { case (condExpr, valueExpr) => | ||
| val cond = condExpr.genCode(ctx) | ||
| val res = valueExpr.genCode(ctx) | ||
| s""" | ||
| ${cond.code} | ||
| if (!${cond.isNull} && ${cond.value}) { | ||
| ${res.code} | ||
| ${ev.isNull} = ${res.isNull}; | ||
| ${ev.value} = ${res.value}; | ||
| if(!$conditionMet) { | ||
| ${cond.code} | ||
| if (!${cond.isNull} && ${cond.value}) { | ||
| ${res.code} | ||
| ${ev.isNull} = ${res.isNull}; | ||
| ${ev.value} = ${res.value}; | ||
| $conditionMet = true; | ||
| } | ||
| } | ||
| """ | ||
| } | ||
|
|
||
| var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") | ||
|
|
||
| elseValue.foreach { elseExpr => | ||
| val elseCode = elseValue.map { elseExpr => | ||
| val res = elseExpr.genCode(ctx) | ||
| generatedCode += | ||
| s""" | ||
| s""" | ||
| if(!$conditionMet) { | ||
| ${res.code} | ||
| ${ev.isNull} = ${res.isNull}; | ||
| ${ev.value} = ${res.value}; | ||
| """ | ||
| } | ||
| """ | ||
| } | ||
|
|
||
| generatedCode += "}\n" * cases.size | ||
| val allConditions = cases ++ elseCode | ||
|
|
||
| val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { | ||
| allConditions.mkString("\n") | ||
| } else { | ||
| ctx.splitExpressions(allConditions, "caseWhen", | ||
| ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_BOOLEAN, conditionMet) :: Nil, | ||
|
||
| returnType = ctx.JAVA_BOOLEAN, | ||
| makeSplitFunction = { | ||
| func => | ||
| s""" | ||
| $func | ||
| return $conditionMet; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we apply the same
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this would complicate the code and I don't think it is worth, since if the code is not split, it means that we don't have many conditions, thus we would save only few
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think in most cases we just split the codes into a few methods, which means, it's more important to apply the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but having this optimization outside means skipping whole methods. Anyway, if you think that this optimization is needed, I can do it. My only concern is that the code readability would be a bit worse, but I'll try to address this problem with comments. |
||
| """ | ||
| }, | ||
| foldFunctions = { funcCalls => | ||
| funcCalls.map { funcCall => | ||
| s""" | ||
| $conditionMet = $funcCall; | ||
| if ($conditionMet) { | ||
| continue; | ||
| }""" | ||
| }.mkString("do {", "", "\n} while (false);") | ||
|
||
| }) | ||
| } | ||
|
|
||
| ev.copy(code = s""" | ||
| boolean ${ev.isNull} = true; | ||
| ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; | ||
| $generatedCode""") | ||
| ${ev.isNull} = true; | ||
| ${ev.value} = ${ctx.defaultValue(dataType)}; | ||
| boolean $conditionMet = false; | ||
| $code""") | ||
| } | ||
| } | ||
|
|
||
|
|
||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we keep this comment and update it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it is necessary, since the generated code is now much simpler and more standard, and no comment like this is provided anywhere else. Anyway, if you feel it is needed, I can add it.