Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ class CodegenContext {

/**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
* As an example, ("int", "count", "count = 0;") will produce code:
* 4-tuple: java type, variable name, code to init it, flag marked if it can be replaced.
* If a state is a WholeStageCodegen result variable, the flag will be true to mark it as
* substitutable for the following generated one.
* As an example, ("int", "count", "count = 0;", false) will produce code:
* {{{
* private int count;
* }}}
Expand All @@ -129,11 +131,17 @@ class CodegenContext {
*
* They will be kept as member variables in generated classes like `SpecificProjection`.
*/
val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
mutable.ArrayBuffer.empty[(String, String, String)]
private var mutableStates: mutable.ArrayBuffer[(String, String, String, Boolean)] =
mutable.ArrayBuffer.empty[(String, String, String, Boolean)]

def addMutableState(javaType: String, variableName: String, initCode: String): Unit = {
mutableStates += ((javaType, variableName, initCode))
def addMutableState(javaType: String, variableName: String,
initCode: String, isWholeStageResultVar: Boolean = false): Unit = {
if (isWholeStageResultVar) {
mutableStates = mutableStates.filterNot(state => (state._1 == javaType) && state._4)
mutableStates += ((javaType, variableName, initCode, isWholeStageResultVar))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you're effectively letting later-added states to override earlier-added ones that are equivalent and replaceable. Could that be done in a add-if-absent way instead?
Also, what's the criteria for a state to be "replaceable" in your design? Could you please make that explicit as comment in the code?

} else {
mutableStates += ((javaType, variableName, initCode, isWholeStageResultVar))
}
}

/**
Expand All @@ -155,7 +163,7 @@ class CodegenContext {
def declareMutableStates(): String = {
// It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
// `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
mutableStates.distinct.map { case (javaType, variableName, _) =>
mutableStates.distinct.map { case (javaType, variableName, _, _) =>
s"private $javaType $variableName;"
}.mkString("\n")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
inputs: Seq[ExprCode],
inputTypes: Seq[DataType],
bufferHolder: String,
isTopLevel: Boolean = false): String = {
isTopLevel: Boolean = false,
isWholeStageResultVar: Boolean = false): String = {
val rowWriterClass = classOf[UnsafeRowWriter].getName
val rowWriter = ctx.freshName("rowWriter")
ctx.addMutableState(rowWriterClass, rowWriter,
s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});",
isWholeStageResultVar)

val resetWriter = if (isTopLevel) {
// For top level row writer, it always writes to the beginning of the global buffer holder,
Expand Down Expand Up @@ -295,7 +297,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def createCode(
ctx: CodegenContext,
expressions: Seq[Expression],
useSubexprElimination: Boolean = false): ExprCode = {
useSubexprElimination: Boolean = false,
isWholeStageResultVar: Boolean = false): ExprCode = {
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val exprTypes = expressions.map(_.dataType)

Expand All @@ -306,12 +309,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}

val result = ctx.freshName("result")
ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
ctx.addMutableState("UnsafeRow", result,
s"$result = new UnsafeRow(${expressions.length});", isWholeStageResultVar)

val holder = ctx.freshName("holder")
val holderClass = classOf[BufferHolder].getName
ctx.addMutableState(holderClass, holder,
s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});", isWholeStageResultVar)

val resetBufferHolder = if (numVarLenFields == 0) {
""
Expand All @@ -328,7 +332,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val evalSubexpr = ctx.subexprFunctions.mkString("\n")

val writeExpressions =
writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
writeExpressionsToBuffer(ctx, ctx.INPUT_ROW,
exprEvals, exprTypes, holder, isTopLevel = true, isWholeStageResultVar)

val code =
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ trait CodegenSupport extends SparkPlan {
// generate the code to create a UnsafeRow
ctx.INPUT_ROW = row
ctx.currentVars = outputVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false, true)
val code = s"""
|$evaluateInputs
|${ev.code.trim}
Expand Down