Skip to content

Commit 353d328

Browse files
peter-toth authored and cloud-fan committed
[SPARK-25768][SQL] fix constant argument expecting UDAFs
## What changes were proposed in this pull request? Without this PR, some UDAFs like `GenericUDAFPercentileApprox` can throw an exception because they expect a constant parameter (object inspector) for a particular argument. The exception is thrown because the `toPrettySQL` call in the `ResolveAliases` analyzer rule transforms a `Literal` parameter into a `PrettyAttribute`, which is then transformed into an `ObjectInspector` instead of a `ConstantObjectInspector`. The exception comes from the `getEvaluator` method of `GenericUDAFPercentileApprox`, which shouldn't actually be called during the `toPrettySQL` transformation. It is called because of the non-lazy fields in `HiveUDAFFunction`. This PR makes all fields of `HiveUDAFFunction` lazy. ## How was this patch tested? Added a new UT. Closes #22766 from peter-toth/SPARK-25768. Authored-by: Peter Toth <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> (cherry picked from commit f38594f) Signed-off-by: Wenchen Fan <[email protected]>
1 parent 61b301c commit 353d328

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -340,39 +340,40 @@ private[hive] case class HiveUDAFFunction(
340340
resolver.getEvaluator(parameterInfo)
341341
}
342342

343-
// The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
344-
@transient
345-
private lazy val partial1ModeEvaluator = newEvaluator()
343+
private case class HiveEvaluator(
344+
evaluator: GenericUDAFEvaluator,
345+
objectInspector: ObjectInspector)
346346

347+
// The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
347348
// Hive `ObjectInspector` used to inspect partial aggregation results.
348349
@transient
349-
private val partialResultInspector = partial1ModeEvaluator.init(
350-
GenericUDAFEvaluator.Mode.PARTIAL1,
351-
inputInspectors
352-
)
350+
private lazy val partial1HiveEvaluator = {
351+
val evaluator = newEvaluator()
352+
HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
353+
}
353354

354355
// The UDAF evaluator used to merge partial aggregation results.
355356
@transient
356357
private lazy val partial2ModeEvaluator = {
357358
val evaluator = newEvaluator()
358-
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
359+
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
359360
evaluator
360361
}
361362

362363
// Spark SQL data type of partial aggregation results
363364
@transient
364-
private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
365+
private lazy val partialResultDataType =
366+
inspectorToDataType(partial1HiveEvaluator.objectInspector)
365367

366368
// The UDAF evaluator used to compute the final result from a partial aggregation result objects.
367-
@transient
368-
private lazy val finalModeEvaluator = newEvaluator()
369-
370369
// Hive `ObjectInspector` used to inspect the final aggregation result object.
371370
@transient
372-
private val returnInspector = finalModeEvaluator.init(
373-
GenericUDAFEvaluator.Mode.FINAL,
374-
Array(partialResultInspector)
375-
)
371+
private lazy val finalHiveEvaluator = {
372+
val evaluator = newEvaluator()
373+
HiveEvaluator(
374+
evaluator,
375+
evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
376+
}
376377

377378
// Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
378379
@transient
@@ -381,7 +382,7 @@ private[hive] case class HiveUDAFFunction(
381382
// Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
382383
// Spark SQL specific format.
383384
@transient
384-
private lazy val resultUnwrapper = unwrapperFor(returnInspector)
385+
private lazy val resultUnwrapper = unwrapperFor(finalHiveEvaluator.objectInspector)
385386

386387
@transient
387388
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -391,7 +392,7 @@ private[hive] case class HiveUDAFFunction(
391392

392393
override def nullable: Boolean = true
393394

394-
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
395+
override lazy val dataType: DataType = inspectorToDataType(finalHiveEvaluator.objectInspector)
395396

396397
override def prettyName: String = name
397398

@@ -401,13 +402,13 @@ private[hive] case class HiveUDAFFunction(
401402
}
402403

403404
override def createAggregationBuffer(): AggregationBuffer =
404-
partial1ModeEvaluator.getNewAggregationBuffer
405+
partial1HiveEvaluator.evaluator.getNewAggregationBuffer
405406

406407
@transient
407408
private lazy val inputProjection = UnsafeProjection.create(children)
408409

409410
override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
410-
partial1ModeEvaluator.iterate(
411+
partial1HiveEvaluator.evaluator.iterate(
411412
buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
412413
buffer
413414
}
@@ -417,12 +418,12 @@ private[hive] case class HiveUDAFFunction(
417418
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
418419
// this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
419420
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
420-
partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
421+
partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
421422
buffer
422423
}
423424

424425
override def eval(buffer: AggregationBuffer): Any = {
425-
resultUnwrapper(finalModeEvaluator.terminate(buffer))
426+
resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
426427
}
427428

428429
override def serialize(buffer: AggregationBuffer): Array[Byte] = {
@@ -439,9 +440,10 @@ private[hive] case class HiveUDAFFunction(
439440

440441
// Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
441442
private class AggregationBufferSerDe {
442-
private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
443+
private val partialResultUnwrapper = unwrapperFor(partial1HiveEvaluator.objectInspector)
443444

444-
private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
445+
private val partialResultWrapper =
446+
wrapperFor(partial1HiveEvaluator.objectInspector, partialResultDataType)
445447

446448
private val projection = UnsafeProjection.create(Array(partialResultDataType))
447449

@@ -451,7 +453,8 @@ private[hive] case class HiveUDAFFunction(
451453
// `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
452454
// that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
453455
// Then we can unwrap it to a Spark SQL value.
454-
mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
456+
mutableRow.update(0, partialResultUnwrapper(
457+
partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
455458
val unsafeRow = projection(mutableRow)
456459
val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
457460
unsafeRow.writeTo(bytes)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
638638
Row(3) :: Row(3) :: Nil)
639639
}
640640
}
641+
642+
test("SPARK-25768 constant argument expecting Hive UDF") {
643+
withTempView("inputTable") {
644+
spark.range(10).createOrReplaceTempView("inputTable")
645+
withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) {
646+
val numFunc = spark.catalog.listFunctions().count()
647+
sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" +
648+
s"${classOf[GenericUDAFPercentileApprox].getName}'")
649+
checkAnswer(
650+
sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable"),
651+
Seq(Row(4.0)))
652+
}
653+
}
654+
}
641655
}
642656

643657
class TestPair(x: Int, y: Int) extends Writable with Serializable {

0 commit comments

Comments
 (0)