@@ -340,39 +340,40 @@ private[hive] case class HiveUDAFFunction(
340340 resolver.getEvaluator(parameterInfo)
341341 }
342342
343- // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
344- @ transient
345- private lazy val partial1ModeEvaluator = newEvaluator( )
343+ private case class HiveEvaluator (
344+ evaluator : GenericUDAFEvaluator ,
345+ objectInspector : ObjectInspector )
346346
347+ // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
347348 // Hive `ObjectInspector` used to inspect partial aggregation results.
348349 @ transient
349- private val partialResultInspector = partial1ModeEvaluator.init(
350- GenericUDAFEvaluator . Mode . PARTIAL1 ,
351- inputInspectors
352- )
350+ private lazy val partial1HiveEvaluator = {
351+ val evaluator = newEvaluator()
352+ HiveEvaluator (evaluator, evaluator.init( GenericUDAFEvaluator . Mode . PARTIAL1 , inputInspectors))
353+ }
353354
354355 // The UDAF evaluator used to merge partial aggregation results.
355356 @ transient
356357 private lazy val partial2ModeEvaluator = {
357358 val evaluator = newEvaluator()
358- evaluator.init(GenericUDAFEvaluator .Mode .PARTIAL2 , Array (partialResultInspector ))
359+ evaluator.init(GenericUDAFEvaluator .Mode .PARTIAL2 , Array (partial1HiveEvaluator.objectInspector ))
359360 evaluator
360361 }
361362
362363 // Spark SQL data type of partial aggregation results
363364 @ transient
364- private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
365+ private lazy val partialResultDataType =
366+ inspectorToDataType(partial1HiveEvaluator.objectInspector)
365367
366368 // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
367- @ transient
368- private lazy val finalModeEvaluator = newEvaluator()
369-
370369 // Hive `ObjectInspector` used to inspect the final aggregation result object.
371370 @ transient
372- private val returnInspector = finalModeEvaluator.init(
373- GenericUDAFEvaluator .Mode .FINAL ,
374- Array (partialResultInspector)
375- )
371+ private lazy val finalHiveEvaluator = {
372+ val evaluator = newEvaluator()
373+ HiveEvaluator (
374+ evaluator,
375+ evaluator.init(GenericUDAFEvaluator .Mode .FINAL , Array (partial1HiveEvaluator.objectInspector)))
376+ }
376377
377378 // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
378379 @ transient
@@ -381,7 +382,7 @@ private[hive] case class HiveUDAFFunction(
381382 // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
382383 // Spark SQL specific format.
383384 @ transient
384- private lazy val resultUnwrapper = unwrapperFor(returnInspector )
385+ private lazy val resultUnwrapper = unwrapperFor(finalHiveEvaluator.objectInspector )
385386
386387 @ transient
387388 private lazy val cached : Array [AnyRef ] = new Array [AnyRef ](children.length)
@@ -391,7 +392,7 @@ private[hive] case class HiveUDAFFunction(
391392
392393 override def nullable : Boolean = true
393394
394- override lazy val dataType : DataType = inspectorToDataType(returnInspector )
395+ override lazy val dataType : DataType = inspectorToDataType(finalHiveEvaluator.objectInspector )
395396
396397 override def prettyName : String = name
397398
@@ -401,13 +402,13 @@ private[hive] case class HiveUDAFFunction(
401402 }
402403
403404 override def createAggregationBuffer (): AggregationBuffer =
404- partial1ModeEvaluator .getNewAggregationBuffer
405+ partial1HiveEvaluator.evaluator .getNewAggregationBuffer
405406
406407 @ transient
407408 private lazy val inputProjection = UnsafeProjection .create(children)
408409
409410 override def update (buffer : AggregationBuffer , input : InternalRow ): AggregationBuffer = {
410- partial1ModeEvaluator .iterate(
411+ partial1HiveEvaluator.evaluator .iterate(
411412 buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
412413 buffer
413414 }
@@ -417,12 +418,12 @@ private[hive] case class HiveUDAFFunction(
417418 // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
418419 // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
419420 // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
420- partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator .terminatePartial(input))
421+ partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator .terminatePartial(input))
421422 buffer
422423 }
423424
424425 override def eval (buffer : AggregationBuffer ): Any = {
425- resultUnwrapper(finalModeEvaluator .terminate(buffer))
426+ resultUnwrapper(finalHiveEvaluator.evaluator .terminate(buffer))
426427 }
427428
428429 override def serialize (buffer : AggregationBuffer ): Array [Byte ] = {
@@ -439,9 +440,10 @@ private[hive] case class HiveUDAFFunction(
439440
440441 // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
441442 private class AggregationBufferSerDe {
442- private val partialResultUnwrapper = unwrapperFor(partialResultInspector )
443+ private val partialResultUnwrapper = unwrapperFor(partial1HiveEvaluator.objectInspector )
443444
444- private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
445+ private val partialResultWrapper =
446+ wrapperFor(partial1HiveEvaluator.objectInspector, partialResultDataType)
445447
446448 private val projection = UnsafeProjection .create(Array (partialResultDataType))
447449
@@ -451,7 +453,8 @@ private[hive] case class HiveUDAFFunction(
451453 // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
452454 // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
453455 // Then we can unwrap it to a Spark SQL value.
454- mutableRow.update(0 , partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
456+ mutableRow.update(0 , partialResultUnwrapper(
457+ partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
455458 val unsafeRow = projection(mutableRow)
456459 val bytes = ByteBuffer .allocate(unsafeRow.getSizeInBytes)
457460 unsafeRow.writeTo(bytes)
0 commit comments