Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -473,37 +473,87 @@ class CodegenContext {
*
* @param row the variable name of row that is used by expressions
* @param expressions the codes to evaluate expressions.
* @param dataTypes the data types of the expressions, in order. Required when the expressions
* read from currentVars instead of the input row. Default value is null.
*/
// NOTE(review): this span appears to be rendered from a diff, so superseded lines (such as the
// one-line signature directly below) are shown next to the lines that replace them.
def splitExpressions(row: String, expressions: Seq[String]): String = {
def splitExpressions(
row: String,
expressions: Seq[String],
dataTypes: Seq[DataType] = null): String = {
// When splitting on currentVars, each variable is passed to a helper function as a typed
// parameter, and that parameter type is derived from dataTypes, so dataTypes is mandatory.
if (currentVars != null && dataTypes == null) {
throw new IllegalArgumentException("cannot split expressions based on currentVars " +
"without expression types")
}
// currentVars is indexed with the position of each expression below, so the two sequences
// must have the same length.
// NOTE(review): dataTypes is indexed the same way but its length is not checked here.
if (currentVars != null && currentVars.size != expressions.size) {
throw new IllegalArgumentException("cannot split expressions based on currentVars " +
"that doesn't have the same length as the given expressions.")
}
// Completed chunks of code; when there is more than one, each becomes its own function.
val blocks = new ArrayBuffer[String]()
// Code accumulated for the chunk that is currently being filled.
val blockBuilder = new StringBuilder()
for (code <- expressions) {
// One entry per chunk: the formal parameter list ("type name, ...") of its function.
val vars = new ArrayBuffer[String]()
// One entry per chunk: the argument list used when calling its function.
val callArgs = new ArrayBuffer[String]()
// (Java type, variable name) pairs collected for the chunk currently being filled.
val varsBuilder = new ArrayBuffer[(String, String)]()
for ((code, i) <- expressions.zipWithIndex) {
// The amount of bytecode that will be generated is not known here, so the length of the
// accumulated code text is used as the limit instead.
if (blockBuilder.length > 64 * 1000) {
// When splitting on currentVars instead of the input row, the number of function call
// parameters collected so far is also used as a limit, to prevent stack overflow.
// The check runs before the current code is appended, so a chunk is only closed once it has
// already grown past a limit.
if ((currentVars != null && varsBuilder.size > 100) || blockBuilder.length > 64 * 1000) {
blocks.append(blockBuilder.toString())
blockBuilder.clear()

// Record the parameter and argument lists of the chunk that was just closed, then start
// collecting afresh for the next chunk.
if (currentVars != null) {
val argsForBlock = varsBuilder.map { case (argType, argValue) =>
s"$argType $argValue"
}.mkString(", ")
vars.append(argsForBlock)
callArgs.append(varsBuilder.map(_._2).mkString(", "))
varsBuilder.clear()
}
}

blockBuilder.append(code)
// Remember the variable that belongs to this expression together with its Java type.
if (currentVars != null) {
val argType = javaType(dataTypes(i))
varsBuilder.append((argType, currentVars(i).value))
}
}

// Close the last chunk. This append is unconditional, so blocks always has at least one entry.
blocks.append(blockBuilder.toString())
if (currentVars != null) {
val argsForBlock = varsBuilder.map { case (argType, argValue) =>
s"$argType $argValue"
}.mkString(", ")
vars.append(argsForBlock)
callArgs.append(varsBuilder.map(_._2).mkString(", "))
}

if (blocks.length == 1) {
// inline execution if only one block
blocks.head
} else {
// Wrap every chunk in a private void function registered through addNewFunction; the
// returned code is the sequence of calls to those functions, one per line.
val apply = freshName("apply")
val functions = blocks.zipWithIndex.map { case (body, i) =>
// Parameters are the collected variables when splitting on currentVars, else the input row.
val functionArgs = if (currentVars != null) {
vars(i)
} else {
s"InternalRow $row"
}
val name = s"${apply}_$i"
val code = s"""
|private void $name(InternalRow $row) {
|private void $name($functionArgs) {
| $body
|}
""".stripMargin
addNewFunction(name, code)
name
// The call statement passes the same variables, or the row, that the function declares.
if (currentVars != null) {
val args = callArgs(i)
s"$name($args);"
} else {
s"$name($row);"
}
}

functions.map(name => s"$name($row);").mkString("\n")
functions.mkString("\n")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

s"""
$resetWriter
${ctx.splitExpressions(row, writeFields)}
${ctx.splitExpressions(row, writeFields, inputTypes)}
""".trim
}

Expand Down Expand Up @@ -327,6 +327,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Evaluate all the subexpression.
val evalSubexpr = ctx.subexprFunctions.mkString("\n")

ctx.currentVars = null
val writeExpressions =
writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
}
"""
}
val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes, children.map(_.dataType))
val schemaField = ctx.addReferenceObj("schema", schema)
s"""
boolean ${ev.isNull} = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
// the generated code will be huge if there are too many columns
val hasTooManyOutputFields =
numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
val hasTooManyInputFields =
plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
!willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
!willFallback && !hasTooManyOutputFields
case _ => false
}

Expand Down