Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -339,40 +339,38 @@ private[hive] case class HiveUDAFFunction(
val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
resolver.getEvaluator(parameterInfo)
}

case class Mode(evaluator: GenericUDAFEvaluator, objectInspector: ObjectInspector)

// The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
@transient
private lazy val partial1ModeEvaluator = newEvaluator()

// Hive `ObjectInspector` used to inspect partial aggregation results.
@transient
private val partialResultInspector = partial1ModeEvaluator.init(
GenericUDAFEvaluator.Mode.PARTIAL1,
inputInspectors
)
private lazy val partial1Mode = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

partial1ModeEvaluator

val evaluator = newEvaluator()
Mode(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
}

// The UDAF evaluator used to merge partial aggregation results.
@transient
private lazy val partial2ModeEvaluator = {
val evaluator = newEvaluator()
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1Mode.objectInspector))
evaluator
}

// Spark SQL data type of partial aggregation results
@transient
private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
private lazy val partialResultDataType = inspectorToDataType(partial1Mode.objectInspector)

// The UDAF evaluator used to compute the final result from a partial aggregation result objects.
@transient
private lazy val finalModeEvaluator = newEvaluator()

// Hive `ObjectInspector` used to inspect the final aggregation result object.
@transient
private val returnInspector = finalModeEvaluator.init(
GenericUDAFEvaluator.Mode.FINAL,
Array(partialResultInspector)
)
private lazy val finalMode = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah it's also used in final mode, then maybe HiveEvaluator is a better name than PartialEvaluator

val evaluator = newEvaluator()
Mode(
evaluator,
evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1Mode.objectInspector)))
}

// Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
@transient
Expand All @@ -381,7 +379,7 @@ private[hive] case class HiveUDAFFunction(
// Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
// Spark SQL specific format.
@transient
private lazy val resultUnwrapper = unwrapperFor(returnInspector)
private lazy val resultUnwrapper = unwrapperFor(finalMode.objectInspector)

@transient
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
Expand All @@ -391,7 +389,7 @@ private[hive] case class HiveUDAFFunction(

override def nullable: Boolean = true

override lazy val dataType: DataType = inspectorToDataType(returnInspector)
override lazy val dataType: DataType = inspectorToDataType(finalMode.objectInspector)

override def prettyName: String = name

Expand All @@ -401,13 +399,13 @@ private[hive] case class HiveUDAFFunction(
}

override def createAggregationBuffer(): AggregationBuffer =
partial1ModeEvaluator.getNewAggregationBuffer
partial1Mode.evaluator.getNewAggregationBuffer

@transient
private lazy val inputProjection = UnsafeProjection.create(children)

override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
partial1ModeEvaluator.iterate(
partial1Mode.evaluator.iterate(
buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
buffer
}
Expand All @@ -417,12 +415,12 @@ private[hive] case class HiveUDAFFunction(
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
// this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
partial2ModeEvaluator.merge(buffer, partial1Mode.evaluator.terminatePartial(input))
buffer
}

override def eval(buffer: AggregationBuffer): Any = {
resultUnwrapper(finalModeEvaluator.terminate(buffer))
resultUnwrapper(finalMode.evaluator.terminate(buffer))
}

override def serialize(buffer: AggregationBuffer): Array[Byte] = {
Expand All @@ -439,9 +437,10 @@ private[hive] case class HiveUDAFFunction(

// Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
private class AggregationBufferSerDe {
private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
private val partialResultUnwrapper = unwrapperFor(partial1Mode.objectInspector)

private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
private val partialResultWrapper =
wrapperFor(partial1Mode.objectInspector, partialResultDataType)

private val projection = UnsafeProjection.create(Array(partialResultDataType))

Expand All @@ -451,7 +450,7 @@ private[hive] case class HiveUDAFFunction(
// `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
// that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
// Then we can unwrap it to a Spark SQL value.
mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
mutableRow.update(0, partialResultUnwrapper(partial1Mode.evaluator.terminatePartial(buffer)))
val unsafeRow = projection(mutableRow)
val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
unsafeRow.writeTo(bytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,21 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
Row(3) :: Row(3) :: Nil)
}
}

test("constant argument expecting Hive UDF") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may you please reference the JIRA?

val testData = spark.range(10).toDF()
withTempView("inputTable") {
testData.createOrReplaceTempView("inputTable")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can move here spark.range(10)

withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) {
val numFunc = spark.catalog.listFunctions().count()
sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" +
s"${classOf[GenericUDAFPercentileApprox].getName}'")
checkAnswer(
sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable"),
Seq(Row(4.0)))
}
}
}
}

class TestPair(x: Int, y: Int) extends Writable with Serializable {
Expand Down