
Commit b040a83

liancheng authored and yhuai committed
[SPARK-18186][SC-5406][BRANCH-2.1] Migrate HiveUDAFFunction to TypedImperativeAggregate for partial aggregation support
While being evaluated in Spark SQL, Hive UDAFs don't support partial aggregation. This PR migrates `HiveUDAFFunction`s to `TypedImperativeAggregate`, which already provides partial aggregation support for aggregate functions that may use arbitrary Java objects as aggregation states.

The following snippet shows the effect of this PR:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax

sql(s"CREATE FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
spark.range(100).createOrReplaceTempView("t")

// A query using both Spark SQL native `max` and Hive `max`
sql(s"SELECT max(id), hive_max(id) FROM t").explain()
```

Before this PR:

```
== Physical Plan ==
SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax@7475f57e), id#1L, false, 0, 0)])
+- Exchange SinglePartition
   +- *Range (0, 100, step=1, splits=Some(1))
```

After this PR:

```
== Physical Plan ==
SortAggregate(key=[], functions=[max(id#1L), default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax@5e18a6a7), id#1L, false, 0, 0)])
+- Exchange SinglePartition
   +- SortAggregate(key=[], functions=[partial_max(id#1L), partial_default.hive_max(default.hive_max, HiveFunctionWrapper(org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax,org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax@5e18a6a7), id#1L, false, 0, 0)])
      +- *Range (0, 100, step=1, splits=Some(1))
```

The tricky part of the PR is mostly about updating and passing around aggregation states of `HiveUDAFFunction`s, since the aggregation state of a Hive UDAF may appear in three different forms. Let's take the testing `MockUDAF` added in this PR as an example. This UDAF computes the count of non-null values together with the count of nulls of a given column. Its aggregation state may appear in the following forms at different times:

1. A `MockUDAFBuffer`, which is a concrete subclass of `GenericUDAFEvaluator.AggregationBuffer`

   The form used by the Hive UDAF API. This form is required by the following scenarios:

   - Calling `GenericUDAFEvaluator.iterate()` to update an existing aggregation state with new input values.
   - Calling `GenericUDAFEvaluator.terminate()` to get the final aggregated value from an existing aggregation state.
   - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The existing aggregation state to be updated must be in this form.

   Conversions:

   - To form 2: `GenericUDAFEvaluator.terminatePartial()`
   - To form 3: Convert to form 2 first, and then to 3.

2. An `Object[]` array containing two `java.lang.Long` values

   The form used to interact with Hive's `ObjectInspector`s. This form is required by the following scenarios:

   - Calling `GenericUDAFEvaluator.terminatePartial()` to convert an existing aggregation state in form 1 to form 2.
   - Calling `GenericUDAFEvaluator.merge()` to merge other aggregation states into an existing aggregation state. The input aggregation state must be in this form.

   Conversions:

   - To form 1: No direct method. Create an empty `AggregationBuffer` and merge the form 2 value into it.
   - To form 3: The `unwrapperFor()`/`unwrap()` methods of `HiveInspectors`

3. The byte array that holds the data of an `UnsafeRow` with two `LongType` fields

   The form used by Spark SQL to shuffle partial aggregation results. This form is required because `TypedImperativeAggregate` always asks its subclasses to serialize their aggregation states into a byte array.

   Conversions:

   - To form 1: Convert to form 2 first, and then to 1.
   - To form 2: The `wrapperFor()`/`wrap()` methods of `HiveInspectors`

Here are some micro-benchmark results produced by the most recent master and this PR branch.

Master:

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
w/o groupBy                                    339 /  372          3.1         323.2       1.0X
w/ groupBy                                     503 /  529          2.1         479.7       0.7X
```

This PR:

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
w/o groupBy                                    116 /  126          9.0         110.8       1.0X
w/ groupBy                                     151 /  159          6.9         144.0       0.8X
```

Benchmark code snippet:

```scala
test("Hive UDAF benchmark") {
  val N = 1 << 20

  sparkSession.sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")

  val benchmark = new Benchmark(
    name = "hive udaf vs spark af",
    valuesPerIteration = N,
    minNumIters = 5,
    warmupTime = 5.seconds,
    minTime = 5.seconds,
    outputPerIteration = true
  )

  benchmark.addCase("w/o groupBy") { _ =>
    sparkSession.range(N).agg("id" -> "hive_max").collect()
  }

  benchmark.addCase("w/ groupBy") { _ =>
    sparkSession.range(N).groupBy($"id" % 10).agg("id" -> "hive_max").collect()
  }

  benchmark.run()

  sparkSession.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
}
```

A new test suite, `HiveUDAFSuite`, is added.

Author: Cheng Lian <lian@databricks.com>

Closes apache#15703 from liancheng/partial-agg-hive-udaf.

Author: Cheng Lian <[email protected]>

Closes apache#144 from yhuai/branch-2.1-hive-udaf.
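As an illustration of the conversions above (not part of this commit), here is a minimal, self-contained Scala sketch that pushes a `GenericUDAFMax` state over a `long` column through the three forms using the raw Hive evaluator API. The object name `HiveUDAFFormsDemo` and the sample values are made up for the example:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
import org.apache.hadoop.io.LongWritable

object HiveUDAFFormsDemo {
  def main(args: Array[String]): Unit = {
    val resolver = new GenericUDAFMax
    val typeInfos = Array[TypeInfo](TypeInfoFactory.longTypeInfo)
    val inputOIs: Array[ObjectInspector] =
      Array(PrimitiveObjectInspectorFactory.writableLongObjectInspector)

    // PARTIAL1: consume raw input. The buffer is in form 1 and is updated via iterate().
    val partial1 = resolver.getEvaluator(typeInfos)
    val partialOI = partial1.init(Mode.PARTIAL1, inputOIs)
    val buf = partial1.getNewAggregationBuffer
    Seq(1L, 42L, 7L).foreach { v =>
      partial1.iterate(buf, Array[AnyRef](new LongWritable(v)))
    }

    // Form 1 -> form 2: terminatePartial() yields an object inspectable through `partialOI`.
    val partialResult = partial1.terminatePartial(buf)

    // Form 2 -> form 1: no direct method exists, so create an empty buffer
    // and merge the form 2 value into it (PARTIAL2 mode).
    val partial2 = resolver.getEvaluator(typeInfos)
    partial2.init(Mode.PARTIAL2, Array(partialOI))
    val merged = partial2.getNewAggregationBuffer
    partial2.merge(merged, partialResult)

    // FINAL: terminate() turns the merged form 1 buffer into the final value.
    val fin = resolver.getEvaluator(typeInfos)
    fin.init(Mode.FINAL, Array(partialOI))
    println(fin.terminate(merged)) // 42
  }
}
```

The missing form 2 to form 1 direction is exactly why the `AggregationBufferSerDe.deserialize()` method in the diff below creates a fresh buffer and merges the deserialized value into it.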
1 parent eaf733d commit b040a83

2 files changed: +301 -50 lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 149 additions & 50 deletions
```diff
@@ -17,16 +17,18 @@

 package org.apache.spark.sql.hive

+import java.nio.ByteBuffer
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer

 import org.apache.hadoop.hive.ql.exec._
 import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
-import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector,
-  ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions

 import org.apache.spark.internal.Logging
@@ -58,7 +60,7 @@ case class HiveSimpleUDF(

   @transient
   private lazy val isUDFDeterministic = {
-    val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
+    val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
     udfType != null && udfType.deterministic() && !udfType.stateful()
   }

@@ -75,7 +77,7 @@ case class HiveSimpleUDF(

   @transient
   lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
-    method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
+    method.getGenericReturnType, ObjectInspectorOptions.JAVA))

   @transient
   private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -263,8 +265,35 @@ case class HiveGenericUDTF(
 }

 /**
- * Currently we don't support partial aggregation for queries using Hive UDAF, which may hurt
- * performance a lot.
+ * While being evaluated by Spark SQL, the aggregation state of a Hive UDAF may be in the following
+ * three formats:
+ *
+ * 1. An instance of some concrete `GenericUDAFEvaluator.AggregationBuffer` class
+ *
+ *    This is the native Hive representation of an aggregation state. Hive `GenericUDAFEvaluator`
+ *    methods like `iterate()`, `merge()`, `terminatePartial()`, and `terminate()` use this format.
+ *    We call these methods to evaluate Hive UDAFs.
+ *
+ * 2. A Java object that can be inspected using the `ObjectInspector` returned by the
+ *    `GenericUDAFEvaluator.init()` method.
+ *
+ *    Hive uses this format to produce a serializable aggregation state so that it can shuffle
+ *    partial aggregation results. Whenever we need to convert a Hive `AggregationBuffer` instance
+ *    into a Spark SQL value, we have to convert it to this format first and then do the conversion
+ *    with the help of `ObjectInspector`s.
+ *
+ * 3. A Spark SQL value
+ *
+ *    We use this format for serializing Hive UDAF aggregation states on the Spark side. To be more
+ *    specific, we convert `AggregationBuffer`s into equivalent Spark SQL values, write them into
+ *    `UnsafeRow`s, and then retrieve the byte array behind those `UnsafeRow`s as serialization
+ *    results.
+ *
+ * We may use the following methods to convert the aggregation state back and forth:
+ *
+ * - `wrap()`/`wrapperFor()`: from 3 to 2
+ * - `unwrap()`/`unwrapperFor()`: from 2 to 3
+ * - `GenericUDAFEvaluator.terminatePartial()`: from 1 to 2
  */
 case class HiveUDAFFunction(
     name: String,
@@ -273,89 +302,89 @@ case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with HiveInspectors {
+  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {

   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)

   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)

+  // Hive `ObjectInspector`s for all child expressions (input parameters of the function).
   @transient
-  private lazy val resolver =
-    if (isUDAFBridgeRequired) {
+  private lazy val inputInspectors = children.map(toInspector).toArray
+
+  // Spark SQL data types of input parameters.
+  @transient
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+  private def newEvaluator(): GenericUDAFEvaluator = {
+    val resolver = if (isUDAFBridgeRequired) {
       new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }

-  @transient
-  private lazy val inspectors = children.map(toInspector).toArray
-
-  @transient
-  private lazy val functionAndInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
-    val f = resolver.getEvaluator(parameterInfo)
-    f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+    val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
+    resolver.getEvaluator(parameterInfo)
   }

+  // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
   @transient
-  private lazy val function = functionAndInspector._1
+  private lazy val partial1ModeEvaluator = newEvaluator()

+  // Hive `ObjectInspector` used to inspect partial aggregation results.
   @transient
-  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+  private val partialResultInspector = partial1ModeEvaluator.init(
+    GenericUDAFEvaluator.Mode.PARTIAL1,
+    inputInspectors
+  )

+  // The UDAF evaluator used to merge partial aggregation results.
   @transient
-  private lazy val returnInspector = functionAndInspector._2
+  private lazy val partial2ModeEvaluator = {
+    val evaluator = newEvaluator()
+    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
+    evaluator
+  }

+  // Spark SQL data type of partial aggregation results
   @transient
-  private lazy val unwrapper = unwrapperFor(returnInspector)
+  private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)

+  // The UDAF evaluator used to compute the final result from partial aggregation result objects.
   @transient
-  private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
-
-  override def eval(input: InternalRow): Any = unwrapper(function.evaluate(buffer))
+  private lazy val finalModeEvaluator = newEvaluator()

+  // Hive `ObjectInspector` used to inspect the final aggregation result object.
   @transient
-  private lazy val inputProjection = new InterpretedProjection(children)
+  private val returnInspector = finalModeEvaluator.init(
+    GenericUDAFEvaluator.Mode.FINAL,
+    Array(partialResultInspector)
+  )

+  // Wrapper functions used to wrap Spark SQL input arguments into Hive-specific format.
   @transient
-  private lazy val cached = new Array[AnyRef](children.length)
+  private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray

+  // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
+  // Spark SQL-specific format.
   @transient
-  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
-
-  // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
-  // buffer for it.
-  override def aggBufferSchema: StructType = StructType(Nil)
-
-  override def update(_buffer: InternalRow, input: InternalRow): Unit = {
-    val inputs = inputProjection(input)
-    function.iterate(buffer, wrap(inputs, wrappers, cached, inputDataTypes))
-  }
-
-  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
-    throw new UnsupportedOperationException(
-      "Hive UDAF doesn't support partial aggregate")
-  }
+  private lazy val resultUnwrapper = unwrapperFor(returnInspector)

-  override def initialize(_buffer: InternalRow): Unit = {
-    buffer = function.getNewAggregationBuffer
-  }
-
-  override val aggBufferAttributes: Seq[AttributeReference] = Nil
+  @transient
+  private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)

-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override val inputAggBufferAttributes: Seq[AttributeReference] = Nil
+  @transient
+  private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe

   // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
   // catalyst type checking framework.
   override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)

   override def nullable: Boolean = true

-  override def supportsPartial: Boolean = false
+  override def supportsPartial: Boolean = true

   override lazy val dataType: DataType = inspectorToDataType(returnInspector)

@@ -365,4 +394,74 @@ case class HiveUDAFFunction(
     val distinct = if (isDistinct) "DISTINCT " else " "
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
+
+  override def createAggregationBuffer(): AggregationBuffer =
+    partial1ModeEvaluator.getNewAggregationBuffer
+
+  @transient
+  private lazy val inputProjection = UnsafeProjection.create(children)
+
+  override def update(buffer: AggregationBuffer, input: InternalRow): Unit = {
+    partial1ModeEvaluator.iterate(
+      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+  }
+
+  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = {
+    // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
+    // buffer in the 2nd format mentioned in the ScalaDoc of this class. Originally, Hive converts
+    // these `AggregationBuffer`s into this format before shuffling partial aggregation results,
+    // and calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
+    partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+  }
+
+  override def eval(buffer: AggregationBuffer): Any = {
+    resultUnwrapper(finalModeEvaluator.terminate(buffer))
+  }
+
+  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+    // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
+    // shuffle it for global aggregation later.
+    aggBufferSerDe.serialize(buffer)
+  }
+
+  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+    // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
+    // for global aggregation by merging multiple partial aggregation results within a single group.
+    aggBufferSerDe.deserialize(bytes)
+  }
+
+  // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
+  private class AggregationBufferSerDe {
+    private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
+
+    private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
+
+    private val projection = UnsafeProjection.create(Array(partialResultDataType))
+
+    private val mutableRow = new GenericInternalRow(1)
+
+    def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
+      // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
+      // Then we can unwrap it to a Spark SQL value.
+      mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
+      val unsafeRow = projection(mutableRow)
+      val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
+      unsafeRow.writeTo(bytes)
+      bytes.array()
+    }
+
+    def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+      // `GenericUDAFEvaluator` doesn't provide any method capable of converting an object
+      // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
+      // workaround here is creating an initial `AggregationBuffer` first and then merging the
+      // deserialized object into the buffer.
+      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val unsafeRow = new UnsafeRow(1)
+      unsafeRow.pointTo(bytes, bytes.length)
+      val partialResult = unsafeRow.get(0, partialResultDataType)
+      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      buffer
+    }
+  }
 }
```
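As a side note, the `UnsafeRow` byte round trip that `AggregationBufferSerDe` performs can be reproduced in isolation. Below is a minimal sketch, assuming a single `LongType` partial result; it uses the same Spark-internal catalyst classes as the patch, and the demo object name is illustrative:

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, LongType}

object UnsafeRowRoundTripDemo {
  def main(args: Array[String]): Unit = {
    // Serialize: wrap the partial result in an InternalRow, project it to an
    // UnsafeRow, and copy the row's backing bytes.
    val projection = UnsafeProjection.create(Array[DataType](LongType))
    val mutableRow = new GenericInternalRow(1)
    mutableRow.update(0, 42L)

    val unsafeRow = projection(mutableRow)
    val buffer = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
    unsafeRow.writeTo(buffer)
    val bytes = buffer.array()

    // Deserialize: point an empty UnsafeRow with the same schema at the bytes.
    val restored = new UnsafeRow(1)
    restored.pointTo(bytes, bytes.length)
    println(restored.getLong(0)) // 42
  }
}
```

Because the serialized form is just the row's backing byte array, deserialization needs no copying: `pointTo()` makes the `UnsafeRow` read fields directly from the received buffer.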

0 commit comments