diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 52d0450afb181..100048f603ad0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -159,7 +159,11 @@ case class HashAggregateExec( // The variables are used as aggregation buffers and each aggregate function has one or more // ExprCode to initialize its buffer slots. Only used for aggregation without keys. - private var bufVars: Seq[Seq[ExprCode]] = _ + private val bufVar = new ThreadLocal[Seq[Seq[ExprCode]]] { + override def initialValue(): Seq[Seq[ExprCode]] = { + Seq.empty + } + } private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -169,7 +173,7 @@ case class HashAggregateExec( // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.map(f => f.initialValues) - bufVars = initExpr.map { exprs => + val localBufVars = initExpr.map { exprs => exprs.map { e => val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") @@ -185,7 +189,8 @@ case class HashAggregateExec( JavaCode.global(value, e.dataType)) } } - val flatBufVars = bufVars.flatten + bufVars.set(localBufVars) + val flatBufVars = bufVars.get().flatten val initBufVar = evaluateVariables(flatBufVars) // generate variables for output @@ -322,7 +327,7 @@ case class HashAggregateExec( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars.flatten ++ input + ctx.currentVars = bufVars.get().flatten ++ input val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } @@ -336,7 +341,7 @@ case class HashAggregateExec( val aggNames = functions.map(_.prettyName) val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) => - val bufVarsForOneFunc = bufVars(i) + val bufVarsForOneFunc = bufVars.get()(i) // All the update code for aggregation buffers should be placed in the end // of each aggregation function code. val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>