Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ case class HashAggregateExec(

// The variables are used as aggregation buffers and each aggregate function has one or more
// ExprCode to initialize its buffer slots. Only used for aggregation without keys.
private var bufVars: Seq[Seq[ExprCode]] = _
private val bufVar = new ThreadLocal[Seq[Seq[ExprCode]]] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this too, but my concern is that we never promise the codegen to be thread-safe. We may have more places to fix, and we may break it again in the future.

override def initialValue(): Seq[Seq[ExprCode]] = {
Seq.empty
}
}

private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
Expand All @@ -169,7 +173,7 @@ case class HashAggregateExec(
// generate variables for aggregation buffer
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.map(f => f.initialValues)
bufVars = initExpr.map { exprs =>
val localBufVars = initExpr.map { exprs =>
exprs.map { e =>
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
Expand All @@ -185,7 +189,8 @@ case class HashAggregateExec(
JavaCode.global(value, e.dataType))
}
}
val flatBufVars = bufVars.flatten
bufVars.set(localBufVars)
val flatBufVars = bufVars.get().flatten
val initBufVar = evaluateVariables(flatBufVars)

// generate variables for output
Expand Down Expand Up @@ -322,7 +327,7 @@ case class HashAggregateExec(
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
}
}
ctx.currentVars = bufVars.flatten ++ input
ctx.currentVars = bufVars.get().flatten ++ input
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttrs)
}
Expand All @@ -336,7 +341,7 @@ case class HashAggregateExec(

val aggNames = functions.map(_.prettyName)
val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
val bufVarsForOneFunc = bufVars(i)
val bufVarsForOneFunc = bufVars.get()(i)
// All the update code for aggregation buffers should be placed in the end
// of each aggregation function code.
val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
Expand Down