sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala (28 additions, 25 deletions)

@@ -340,39 +340,40 @@ private[hive] case class HiveUDAFFunction(
     resolver.getEvaluator(parameterInfo)
   }
 
-  // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
-  @transient
-  private lazy val partial1ModeEvaluator = newEvaluator()
+  private case class HiveEvaluator(
+      evaluator: GenericUDAFEvaluator,
+      objectInspector: ObjectInspector)
 
-  // Hive `ObjectInspector` used to inspect partial aggregation results.
+  // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
   @transient
-  private val partialResultInspector = partial1ModeEvaluator.init(
-    GenericUDAFEvaluator.Mode.PARTIAL1,
-    inputInspectors
-  )
+  private lazy val partial1HiveEvaluator = {
+    val evaluator = newEvaluator()
+    HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
+  }
 
   // The UDAF evaluator used to merge partial aggregation results.
   @transient
   private lazy val partial2ModeEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
+    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
     evaluator
   }
 
   // Spark SQL data type of partial aggregation results
   @transient
-  private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
+  private lazy val partialResultDataType =
+    inspectorToDataType(partial1HiveEvaluator.objectInspector)
 
   // The UDAF evaluator used to compute the final result from partial aggregation result objects.
   @transient
-  private lazy val finalModeEvaluator = newEvaluator()
-
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
-  @transient
-  private val returnInspector = finalModeEvaluator.init(
-    GenericUDAFEvaluator.Mode.FINAL,
-    Array(partialResultInspector)
-  )
+  private lazy val finalHiveEvaluator = {
+    val evaluator = newEvaluator()
+    HiveEvaluator(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
+  }
 
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
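The three evaluators in the hunk above implement Hive's two-phase aggregation protocol: a PARTIAL1-mode evaluator consumes raw rows and emits partial buffers (`iterate`/`terminatePartial`), a PARTIAL2-mode evaluator merges shuffled partials (`merge`), and a FINAL-mode evaluator produces the result (`terminate`). A minimal, self-contained sketch of that flow for an "average" aggregate, using plain Scala stand-ins rather than Hive's actual `GenericUDAFEvaluator` API:

// Simplified sketch of the two-phase UDAF protocol referenced above.
// Illustrative stand-in types, not Hive's evaluator API.
object TwoPhaseAverageSketch {
  final case class Buffer(var sum: Double = 0.0, var count: Long = 0L)

  // PARTIAL1: consume raw input rows into a per-partition buffer.
  def iterate(buf: Buffer, value: Double): Unit = { buf.sum += value; buf.count += 1 }

  // PARTIAL1: convert a buffer into a shuffle-friendly partial result.
  def terminatePartial(buf: Buffer): (Double, Long) = (buf.sum, buf.count)

  // PARTIAL2: merge a shuffled partial result into a buffer.
  def merge(buf: Buffer, partial: (Double, Long)): Unit = {
    buf.sum += partial._1
    buf.count += partial._2
  }

  // FINAL: compute the final value from the fully merged buffer.
  def terminate(buf: Buffer): Option[Double] =
    if (buf.count == 0) None else Some(buf.sum / buf.count)

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1.0, 2.0), Seq(3.0, 4.0))
    // Map side (PARTIAL1): one buffer per partition.
    val partials = partitions.map { rows =>
      val b = Buffer()
      rows.foreach(iterate(b, _))
      terminatePartial(b)
    }
    // Reduce side (PARTIAL2, then FINAL).
    val merged = Buffer()
    partials.foreach(merge(merged, _))
    println(terminate(merged)) // Some(2.5)
  }
}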
@@ -381,7 +382,7 @@ private[hive] case class HiveUDAFFunction(
   // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
   // Spark SQL specific format.
   @transient
-  private lazy val resultUnwrapper = unwrapperFor(returnInspector)
+  private lazy val resultUnwrapper = unwrapperFor(finalHiveEvaluator.objectInspector)
 
   @transient
   private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -391,7 +392,7 @@ private[hive] case class HiveUDAFFunction(
 
   override def nullable: Boolean = true
 
-  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(finalHiveEvaluator.objectInspector)
 
   override def prettyName: String = name
 
@@ -401,13 +402,13 @@ private[hive] case class HiveUDAFFunction(
   }
 
   override def createAggregationBuffer(): AggregationBuffer =
-    partial1ModeEvaluator.getNewAggregationBuffer
+    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
-    partial1ModeEvaluator.iterate(
+    partial1HiveEvaluator.evaluator.iterate(
       buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     buffer
   }
@@ -417,12 +418,12 @@ private[hive] case class HiveUDAFFunction(
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // these `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
     buffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalModeEvaluator.terminate(buffer))
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
   }
 
   override def serialize(buffer: AggregationBuffer): Array[Byte] = {
@@ -439,9 +440,10 @@ private[hive] case class HiveUDAFFunction(
 
   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
   private class AggregationBufferSerDe {
-    private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
+    private val partialResultUnwrapper = unwrapperFor(partial1HiveEvaluator.objectInspector)
 
-    private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
+    private val partialResultWrapper =
+      wrapperFor(partial1HiveEvaluator.objectInspector, partialResultDataType)
 
     private val projection = UnsafeProjection.create(Array(partialResultDataType))
 
@@ -451,7 +453,8 @@ private[hive] case class HiveUDAFFunction(
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
      // Then we can unwrap it to a Spark SQL value.
-      mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
+      mutableRow.update(0, partialResultUnwrapper(
+        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
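The `AggregationBufferSerDe` above exists because Hive `AggregationBuffer`s are opaque objects that Spark cannot shuffle directly: `terminatePartial()` first turns the buffer into an inspectable value, which is then packed into a single-column `UnsafeRow` and copied out as bytes. A stand-in sketch of that serialize/deserialize round trip, using a plain `ByteBuffer` in place of Spark's `UnsafeProjection` machinery (simplified types, not Spark's or Hive's classes):

import java.nio.ByteBuffer

// Stand-in sketch of the SerDe round trip: buffer -> partial pair -> bytes -> buffer.
object BufferSerDeSketch {
  final case class Buffer(sum: Double, count: Long)

  // Analogous to terminatePartial(): expose the buffer in a serializable shape.
  private def toPartial(b: Buffer): (Double, Long) = (b.sum, b.count)

  def serialize(b: Buffer): Array[Byte] = {
    val (sum, count) = toPartial(b)
    val bytes = ByteBuffer.allocate(java.lang.Double.BYTES + java.lang.Long.BYTES)
    bytes.putDouble(sum).putLong(count)
    bytes.array()
  }

  def deserialize(data: Array[Byte]): Buffer = {
    val bytes = ByteBuffer.wrap(data)
    Buffer(bytes.getDouble(), bytes.getLong())
  }

  def main(args: Array[String]): Unit = {
    println(deserialize(serialize(Buffer(10.0, 4L)))) // Buffer(10.0,4)
  }
}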
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala

@@ -638,6 +638,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       Row(3) :: Row(3) :: Nil)
     }
   }
+
+  test("SPARK-25768 constant argument expecting Hive UDF") {
+    withTempView("inputTable") {
+      spark.range(10).createOrReplaceTempView("inputTable")
+      withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) {
+        val numFunc = spark.catalog.listFunctions().count()
+        sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" +
+          s"${classOf[GenericUDAFPercentileApprox].getName}'")
+        checkAnswer(
+          sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable"),
+          Seq(Row(4.0)))
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
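For orientation, a usage sketch of the scenario the test exercises: calling a Hive percentile UDAF whose second argument must be a constant literal. This assumes a local Spark build with Hive support on the classpath; the built-in percentile_approx function stands in here for the test's registered testGenericUDAFPercentileApprox.

import org.apache.spark.sql.SparkSession

// Usage sketch (assumption: local Spark with Hive support available).
object ConstantArgHiveUdafExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("constant-argument-hive-udaf")
      .enableHiveSupport()
      .getOrCreate()

    spark.range(10).createOrReplaceTempView("inputTable")

    // The second argument (the percentile) must be a constant,
    // mirroring the SELECT in the test above.
    spark.sql("SELECT percentile_approx(id, 0.5) FROM inputTable").show()

    spark.stop()
  }
}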