diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index ee7f4fadca89..f658823d74ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -473,19 +473,59 @@ class CodegenContext {
    *
    * @param row the variable name of row that is used by expressions
    * @param expressions the codes to evaluate expressions.
+   * @param dataTypes the data types of the expressions. This must be provided when the
+   *                  expressions consume currentVars instead of the input row; defaults to null.
    */
-  def splitExpressions(row: String, expressions: Seq[String]): String = {
+  def splitExpressions(
+      row: String,
+      expressions: Seq[String],
+      dataTypes: Seq[DataType] = null): String = {
+    if (currentVars != null && dataTypes == null) {
+      throw new IllegalArgumentException("cannot split expressions based on currentVars " +
+        "without expression types")
+    }
+    if (currentVars != null && currentVars.size != expressions.size) {
+      throw new IllegalArgumentException("cannot split expressions based on currentVars: " +
+        "currentVars must have the same length as the given expressions")
+    }
     val blocks = new ArrayBuffer[String]()
     val blockBuilder = new StringBuilder()
-    for (code <- expressions) {
+    val vars = new ArrayBuffer[String]()
+    val callArgs = new ArrayBuffer[String]()
+    val varsBuilder = new ArrayBuffer[(String, String)]()
+    for ((code, i) <- expressions.zipWithIndex) {
       // We can't know how many byte code will be generated, so use the number of bytes as limit
-      if (blockBuilder.length > 64 * 1000) {
+      // When we split expressions that read currentVars instead of an input row, we also cap
+      // the number of function call parameters, to prevent a stack overflow.
+      if ((currentVars != null && varsBuilder.size > 100) || blockBuilder.length > 64 * 1000) {
         blocks.append(blockBuilder.toString())
         blockBuilder.clear()
+
+        if (currentVars != null) {
+          val argsForBlock = varsBuilder.map { case (argType, argValue) =>
+            s"$argType $argValue"
+          }.mkString(", ")
+          vars.append(argsForBlock)
+          callArgs.append(varsBuilder.map(_._2).mkString(", "))
+          varsBuilder.clear()
+        }
       }
+
       blockBuilder.append(code)
+      if (currentVars != null) {
+        val argType = javaType(dataTypes(i))
+        varsBuilder.append((argType, currentVars(i).value))
+      }
     }
+
     blocks.append(blockBuilder.toString())
+    if (currentVars != null) {
+      val argsForBlock = varsBuilder.map { case (argType, argValue) =>
+        s"$argType $argValue"
+      }.mkString(", ")
+      vars.append(argsForBlock)
+      callArgs.append(varsBuilder.map(_._2).mkString(", "))
+    }
 
     if (blocks.length == 1) {
       // inline execution if only one block
@@ -493,17 +533,27 @@ class CodegenContext {
     } else {
       val apply = freshName("apply")
       val functions = blocks.zipWithIndex.map { case (body, i) =>
+        val functionArgs = if (currentVars != null) {
+          vars(i)
+        } else {
+          s"InternalRow $row"
+        }
         val name = s"${apply}_$i"
         val code = s"""
-          |private void $name(InternalRow $row) {
+          |private void $name($functionArgs) {
           |  $body
           |}
         """.stripMargin
         addNewFunction(name, code)
-        name
+        if (currentVars != null) {
+          val args = callArgs(i)
+          s"$name($args);"
+        } else {
+          s"$name($row);"
+        }
       }
-      functions.map(name => s"$name($row);").mkString("\n")
+      functions.mkString("\n")
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 6aa9cbf08bdb..9108e6b4c6c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -164,7 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     s"""
       $resetWriter
-      ${ctx.splitExpressions(row, writeFields)}
+      ${ctx.splitExpressions(row, writeFields, inputTypes)}
     """.trim
   }
@@ -327,6 +327,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     // Evaluate all the subexpression.
     val evalSubexpr = ctx.subexprFunctions.mkString("\n")
+    ctx.currentVars = null
     val writeExpressions = writeExpressionsToBuffer(
       ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 26b1ff39b3e9..dfee527783da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -538,7 +538,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
         }
       """
     }
-    val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
+    val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes, children.map(_.dataType))
     val schemaField = ctx.addReferenceObj("schema", schema)
     s"""
       boolean ${ev.isNull} = false;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 447dbe701815..bd8647dddb07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -448,9 +448,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
       // the generated code will be huge if there are too many columns
       val hasTooManyOutputFields = numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
-      val hasTooManyInputFields =
-        plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
-      !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
+      !willFallback && !hasTooManyOutputFields
     case _ => false
   }
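
Editor's note on the splitting strategy. The patch generalizes CodegenContext.splitExpressions so that generated code can be split into helper functions even when the expressions read bound columns (currentVars) rather than an InternalRow: each block's column values are passed through the helper's parameter list instead of a single row argument. The sketch below is a minimal, self-contained Scala illustration of that shape only, not Spark's actual API. The names Snippet, SplitSketch, splitIntoFunctions, maxBlockChars, and maxParams are invented for this note; Snippet stands in for the generated-code descriptors Spark uses (the currentVars(i).value accesses above), and the default thresholds mirror the patch's 64 * 1000 characters and 100 parameters.

    import scala.collection.mutable.ArrayBuffer

    // One generated Java snippet plus the Java type and variable name of the
    // column value it consumes (hypothetical stand-in, not a Spark class).
    case class Snippet(code: String, javaType: String, value: String)

    object SplitSketch {
      // Group snippets into blocks so that no block's source text exceeds
      // maxBlockChars (a proxy for the JVM's 64KB-per-method bytecode limit)
      // and no block consumes more than maxParams column values (so helper
      // parameter lists stay bounded). Returns the helper declarations and
      // the code that calls them in order.
      def splitIntoFunctions(
          snippets: Seq[Snippet],
          maxBlockChars: Int = 64 * 1000,
          maxParams: Int = 100): (Seq[String], String) = {
        val blocks = ArrayBuffer(ArrayBuffer.empty[Snippet])
        for (s <- snippets) {
          val last = blocks.last
          val wouldOverflow =
            last.map(_.code.length).sum + s.code.length > maxBlockChars ||
            last.size >= maxParams
          if (last.nonEmpty && wouldOverflow) {
            blocks += ArrayBuffer.empty[Snippet]
          }
          blocks.last += s
        }

        val perBlock = blocks.toSeq.zipWithIndex.map { case (block, i) =>
          val params = block.map(s => s"${s.javaType} ${s.value}").mkString(", ")
          val args = block.map(_.value).mkString(", ")
          val body = block.map(_.code).mkString("\n")
          // In Spark the declaration would be registered via ctx.addNewFunction;
          // here we simply return it alongside its call site.
          (s"private void apply_$i($params) {\n$body\n}", s"apply_$i($args);")
        }
        (perBlock.map(_._1), perBlock.map(_._2).mkString("\n"))
      }
    }

For example, feeding 250 short snippets through this sketch with the default limits yields three helpers, apply_0 through apply_2, each taking at most 100 parameters, plus a call-site string that invokes them in sequence. This is consistent with the last hunk dropping the hasTooManyInputFields bailout in CollapseCodegenStages: once wide input rows can be spread across bounded parameter lists, whole-stage codegen need not refuse a plan merely because a child schema is wide.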