diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 33b9b804fc60..a3656e514671 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -116,8 +116,10 @@ class CodegenContext { /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a - * 3-tuple: java type, variable name, code to init it. - * As an example, ("int", "count", "count = 0;") will produce code: + * 4-tuple: java type, variable name, code to init it, flag marked if it can be replaced. + * If a state is a WholeStageCodegen result variable, the flag will be true to mark it as + * substitutable for the following generated one. + * As an example, ("int", "count", "count = 0;", false) will produce code: * {{{ * private int count; * }}} @@ -129,11 +131,17 @@ class CodegenContext { * * They will be kept as member variables in generated classes like `SpecificProjection`. */ - val mutableStates: mutable.ArrayBuffer[(String, String, String)] = - mutable.ArrayBuffer.empty[(String, String, String)] + private var mutableStates: mutable.ArrayBuffer[(String, String, String, Boolean)] = + mutable.ArrayBuffer.empty[(String, String, String, Boolean)] - def addMutableState(javaType: String, variableName: String, initCode: String): Unit = { - mutableStates += ((javaType, variableName, initCode)) + def addMutableState(javaType: String, variableName: String, + initCode: String, isWholeStageResultVar: Boolean = false): Unit = { + if (isWholeStageResultVar) { + mutableStates = mutableStates.filterNot(state => (state._1 == javaType) && state._4) + mutableStates += ((javaType, variableName, initCode, isWholeStageResultVar)) + } else { + mutableStates += ((javaType, variableName, initCode, isWholeStageResultVar)) + } } /** @@ -155,7 +163,7 @@ class CodegenContext { def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - mutableStates.distinct.map { case (javaType, variableName, _) => + mutableStates.distinct.map { case (javaType, variableName, _, _) => s"private $javaType $variableName;" }.mkString("\n") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a608..de5c62b2465a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -71,11 +71,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro inputs: Seq[ExprCode], inputTypes: Seq[DataType], bufferHolder: String, - isTopLevel: Boolean = false): String = { + isTopLevel: Boolean = false, + isWholeStageResultVar: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") ctx.addMutableState(rowWriterClass, rowWriter, - s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") + s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});", + isWholeStageResultVar) val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, @@ -295,7 +297,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode( ctx: CodegenContext, expressions: Seq[Expression], - useSubexprElimination: Boolean = false): ExprCode = { + useSubexprElimination: Boolean = false, + isWholeStageResultVar: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) @@ -306,12 +309,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") + ctx.addMutableState("UnsafeRow", result, + s"$result = new UnsafeRow(${expressions.length});", isWholeStageResultVar) val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, holder, - s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});") + s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});", isWholeStageResultVar) val resetBufferHolder = if (numVarLenFields == 0) { "" @@ -328,7 +332,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, + exprEvals, exprTypes, holder, isTopLevel = true, isWholeStageResultVar) val code = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fb57ed7692de..0a47b68cc3a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -133,7 +133,7 @@ trait CodegenSupport extends SparkPlan { // generate the code to create a UnsafeRow ctx.INPUT_ROW = row ctx.currentVars = outputVars - val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false, true) val code = s""" |$evaluateInputs |${ev.code.trim}