-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24121][SQL] Add API for handling expression code generation #21193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
1df9943
5fe425c
00bef6b
5d9c454
162deb2
d138ee0
ee9a4c0
e7cfa28
5945c15
2b30654
aff411b
53b329a
72faac3
ffbf4ab
d040676
c378ce2
2ca9741
d91f111
4b49e8a
96c594a
00cc564
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException | |
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.util._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} | ||
|
|
@@ -623,8 +624,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | |
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| val eval = child.genCode(ctx) | ||
| val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) | ||
|
|
||
| // Below the code comment including `eval.value` and `eval.isNull` is a trick. It makes the two | ||
| // expr values are referred by this code block. | ||
| ev.copy(code = eval.code + | ||
| castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) | ||
| code""" | ||
| // Cast from ${eval.value}, ${eval.isNull} | ||
|
||
| ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I might miss something, but don't we need to pass
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can use Inputs to |
||
| """) | ||
| } | ||
|
|
||
| // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ import java.util.Locale | |
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.trees.TreeNode | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.util.Utils | ||
|
|
@@ -100,17 +101,18 @@ abstract class Expression extends TreeNode[Expression] { | |
| ctx.subExprEliminationExprs.get(this).map { subExprState => | ||
| // This expression is repeated which means that the code to evaluate it has already been added | ||
| // as a function before. In that case, we just re-use it. | ||
| ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) | ||
| ExprCode(code"${ctx.registerComment(this.toString)}", subExprState.isNull, | ||
|
||
| subExprState.value) | ||
|
||
| }.getOrElse { | ||
| val isNull = ctx.freshName("isNull") | ||
| val value = ctx.freshName("value") | ||
| val eval = doGenCode(ctx, ExprCode( | ||
| JavaCode.isNullVariable(isNull), | ||
| JavaCode.variable(value, dataType))) | ||
| reduceCodeSize(ctx, eval) | ||
| if (eval.code.nonEmpty) { | ||
| if (eval.code.toString.nonEmpty) { | ||
| // Add `this` in the comment. | ||
| eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) | ||
| eval.copy(code = code"${ctx.registerComment(this.toString)}\n" + eval.code) | ||
| } else { | ||
| eval | ||
| } | ||
|
|
@@ -119,7 +121,7 @@ abstract class Expression extends TreeNode[Expression] { | |
|
|
||
| private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { | ||
| // TODO: support whole stage codegen too | ||
| if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { | ||
| if (eval.code.toString.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { | ||
|
||
| val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { | ||
| val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") | ||
| val localIsNull = eval.isNull | ||
|
|
@@ -136,14 +138,14 @@ abstract class Expression extends TreeNode[Expression] { | |
| val funcFullName = ctx.addNewFunction(funcName, | ||
| s""" | ||
| |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { | ||
| | ${eval.code.trim} | ||
| | ${eval.code} | ||
| | $setIsNull | ||
| | return ${eval.value}; | ||
| |} | ||
| """.stripMargin) | ||
|
|
||
| eval.value = JavaCode.variable(newValue, dataType) | ||
| eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" | ||
| eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -437,15 +439,14 @@ abstract class UnaryExpression extends Expression { | |
|
|
||
| if (nullable) { | ||
| val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) | ||
| ev.copy(code = s""" | ||
| ev.copy(code = code""" | ||
| ${childGen.code} | ||
| boolean ${ev.isNull} = ${childGen.isNull}; | ||
| ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
| $nullSafeEval | ||
| """) | ||
| } else { | ||
| ev.copy(code = s""" | ||
| boolean ${ev.isNull} = false; | ||
| ev.copy(code = code""" | ||
| ${childGen.code} | ||
| ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
| $resultCode""", isNull = FalseLiteral) | ||
|
|
@@ -537,14 +538,13 @@ abstract class BinaryExpression extends Expression { | |
| } | ||
| } | ||
|
|
||
| ev.copy(code = s""" | ||
| ev.copy(code = code""" | ||
| boolean ${ev.isNull} = true; | ||
| ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
| $nullSafeEval | ||
| """) | ||
| } else { | ||
| ev.copy(code = s""" | ||
| boolean ${ev.isNull} = false; | ||
| ev.copy(code = code""" | ||
| ${leftGen.code} | ||
| ${rightGen.code} | ||
| ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
|
|
@@ -681,13 +681,12 @@ abstract class TernaryExpression extends Expression { | |
| } | ||
| } | ||
|
|
||
| ev.copy(code = s""" | ||
| ev.copy(code = code""" | ||
| boolean ${ev.isNull} = true; | ||
| ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
| $nullSafeEval""") | ||
| } else { | ||
| ev.copy(code = s""" | ||
| boolean ${ev.isNull} = false; | ||
| ev.copy(code = code""" | ||
| ${leftGen.code} | ||
| ${midGen.code} | ||
| ${rightGen.code} | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging | |
| import org.apache.spark.metrics.source.CodegenMetrics | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.types._ | ||
|
|
@@ -56,19 +57,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} | |
| * @param value A term for a (possibly primitive) value of the result of the evaluation. Not | ||
| * valid if `isNull` is set to `true`. | ||
| */ | ||
| case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) | ||
| case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) | ||
|
|
||
| object ExprCode { | ||
| def apply(isNull: ExprValue, value: ExprValue): ExprCode = { | ||
| ExprCode(code = "", isNull, value) | ||
| ExprCode(code = code"", isNull, value) | ||
|
||
| } | ||
|
|
||
| def forNullValue(dataType: DataType): ExprCode = { | ||
| ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) | ||
| ExprCode(code = code"", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) | ||
| } | ||
|
|
||
| def forNonNullValue(value: ExprValue): ExprCode = { | ||
| ExprCode(code = "", isNull = FalseLiteral, value = value) | ||
| ExprCode(code = code"", isNull = FalseLiteral, value = value) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -329,9 +330,9 @@ class CodegenContext { | |
| def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { | ||
| val value = addMutableState(javaType(dataType), variableName) | ||
| val code = dataType match { | ||
| case StringType => s"$value = $initCode.clone();" | ||
| case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" | ||
| case _ => s"$value = $initCode;" | ||
| case StringType => code"$value = $initCode.clone();" | ||
| case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" | ||
| case _ => code"$value = $initCode;" | ||
| } | ||
| ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) | ||
| } | ||
|
|
@@ -988,7 +989,7 @@ class CodegenContext { | |
| val eval = expr.genCode(this) | ||
| val state = SubExprEliminationState(eval.isNull, eval.value) | ||
| e.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| eval.code.trim | ||
| eval.code.toString | ||
| } | ||
| SubExprCodes(codes, localSubExprEliminationExprs.toMap) | ||
| } | ||
|
|
@@ -1016,7 +1017,7 @@ class CodegenContext { | |
| val fn = | ||
| s""" | ||
| |private void $fnName(InternalRow $INPUT_ROW) { | ||
| | ${eval.code.trim} | ||
| | ${eval.code} | ||
| | $isNull = ${eval.isNull}; | ||
| | $value = ${eval.value}; | ||
| |} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: how about moving
eval.codeinto the followingcodeinterpolation?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok.