[SPARK-25768][SQL] fix constant argument expecting UDAFs #22766
Changes from 1 commit
Changes to `HiveUDAFFunction`:

```diff
@@ -339,40 +339,38 @@ private[hive] case class HiveUDAFFunction(
     val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
     resolver.getEvaluator(parameterInfo)
   }
 
+  case class Mode(evaluator: GenericUDAFEvaluator, objectInspector: ObjectInspector)
+
   // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
   @transient
-  private lazy val partial1ModeEvaluator = newEvaluator()
-
-  // Hive `ObjectInspector` used to inspect partial aggregation results.
-  @transient
-  private val partialResultInspector = partial1ModeEvaluator.init(
-    GenericUDAFEvaluator.Mode.PARTIAL1,
-    inputInspectors
-  )
+  private lazy val partial1Mode = {
+    val evaluator = newEvaluator()
+    Mode(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
+  }
 
   // The UDAF evaluator used to merge partial aggregation results.
   @transient
   private lazy val partial2ModeEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
+    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1Mode.objectInspector))
     evaluator
   }
 
   // Spark SQL data type of partial aggregation results
   @transient
-  private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
+  private lazy val partialResultDataType = inspectorToDataType(partial1Mode.objectInspector)
 
   // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
   @transient
-  private lazy val finalModeEvaluator = newEvaluator()
-
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
-  @transient
-  private val returnInspector = finalModeEvaluator.init(
-    GenericUDAFEvaluator.Mode.FINAL,
-    Array(partialResultInspector)
-  )
+  private lazy val finalMode = {
+    val evaluator = newEvaluator()
+    Mode(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1Mode.objectInspector)))
+  }
 
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
```
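To make the refactoring above easier to follow outside the diff, here is a minimal, self-contained sketch of the pattern it introduces: each Hive evaluator is kept together with the `ObjectInspector` returned by its own `init()` call, and the pair is built lazily. The `Evaluator` and `Inspector` types below are hypothetical stand-ins for Hive's `GenericUDAFEvaluator` and `ObjectInspector`; this is an illustration of the idea, not the Spark code itself.

```scala
// Hypothetical stand-ins for Hive's GenericUDAFEvaluator and ObjectInspector,
// used only to illustrate the pairing pattern introduced by the diff above.
object EvaluatorPairingSketch {
  final class Inspector(val description: String)

  final class Evaluator {
    // Like GenericUDAFEvaluator.init, returns the inspector describing this
    // evaluator's output for the given mode and input inspectors.
    def init(mode: String, inputs: Seq[Inspector]): Inspector =
      new Inspector(s"$mode over (${inputs.map(_.description).mkString(", ")})")
  }

  // The key idea: keep each evaluator together with the inspector produced by
  // its own init() call, instead of sharing eagerly initialized inspectors.
  final case class Mode(evaluator: Evaluator, objectInspector: Inspector)

  private def newEvaluator(): Evaluator = new Evaluator
  private val inputInspectors = Seq(new Inspector("raw input"))

  // Built lazily, on first use, mirroring `partial1Mode` in the diff.
  lazy val partial1Mode: Mode = {
    val evaluator = newEvaluator()
    Mode(evaluator, evaluator.init("PARTIAL1", inputInspectors))
  }

  // The FINAL pair reuses only the PARTIAL1 *inspector*, mirroring
  // evaluator.init(Mode.FINAL, Array(partial1Mode.objectInspector)).
  lazy val finalMode: Mode = {
    val evaluator = newEvaluator()
    Mode(evaluator, evaluator.init("FINAL", Seq(partial1Mode.objectInspector)))
  }

  def main(args: Array[String]): Unit = {
    println(partial1Mode.objectInspector.description) // PARTIAL1 over (raw input)
    println(finalMode.objectInspector.description)    // FINAL over (PARTIAL1 over (raw input))
  }
}
```

Compared with the previous `partialResultInspector` and `returnInspector` fields, which were non-lazy `val`s initialized from separately created evaluators, each inspector is now created lazily and only from the evaluator it belongs to.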
```diff
@@ -381,7 +379,7 @@ private[hive] case class HiveUDAFFunction(
   // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
   // Spark SQL specific format.
   @transient
-  private lazy val resultUnwrapper = unwrapperFor(returnInspector)
+  private lazy val resultUnwrapper = unwrapperFor(finalMode.objectInspector)
 
   @transient
   private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -391,7 +389,7 @@ private[hive] case class HiveUDAFFunction(
 
   override def nullable: Boolean = true
 
-  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(finalMode.objectInspector)
 
   override def prettyName: String = name
 
@@ -401,13 +399,13 @@ private[hive] case class HiveUDAFFunction(
   }
 
   override def createAggregationBuffer(): AggregationBuffer =
-    partial1ModeEvaluator.getNewAggregationBuffer
+    partial1Mode.evaluator.getNewAggregationBuffer
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
-    partial1ModeEvaluator.iterate(
+    partial1Mode.evaluator.iterate(
       buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     buffer
   }
```
```diff
@@ -417,12 +415,12 @@ private[hive] case class HiveUDAFFunction(
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+    partial2ModeEvaluator.merge(buffer, partial1Mode.evaluator.terminatePartial(input))
     buffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalModeEvaluator.terminate(buffer))
+    resultUnwrapper(finalMode.evaluator.terminate(buffer))
   }
 
   override def serialize(buffer: AggregationBuffer): Array[Byte] = {
```
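As a rough, self-contained illustration of the call sequence the comments above describe (hypothetical `Evaluator` and `Buffer` stand-ins, not the Hive API): `iterate` consumes raw rows on the map side, `terminatePartial` converts a buffer into the shuffle format, `merge` folds those partials into another buffer, and `terminate` produces the final value.

```scala
object UdafLifecycleSketch {
  // Stand-in for Hive's AggregationBuffer: a running sum and count for an average.
  final case class Buffer(var sum: Long = 0L, var count: Long = 0L)

  // Stand-in evaluator exposing the subset of the Hive evaluator API used above.
  final class Evaluator {
    def getNewAggregationBuffer: Buffer = Buffer()
    def iterate(buf: Buffer, input: Long): Unit = { buf.sum += input; buf.count += 1 }
    // terminatePartial turns a buffer into the shuffle-friendly partial format.
    def terminatePartial(buf: Buffer): (Long, Long) = (buf.sum, buf.count)
    def merge(buf: Buffer, partial: (Long, Long)): Unit = {
      buf.sum += partial._1
      buf.count += partial._2
    }
    def terminate(buf: Buffer): Double =
      if (buf.count == 0) 0.0 else buf.sum.toDouble / buf.count
  }

  def main(args: Array[String]): Unit = {
    val partial1 = new Evaluator // consumes raw input rows, as in update()
    val partial2 = new Evaluator // merges partial results, as in merge()
    val finalEv  = new Evaluator // produces the final value, as in eval()

    // Map side: one buffer per partition, fed via iterate().
    val mapBuffers = Seq(Seq(1L, 2L, 3L), Seq(4L, 5L)).map { partition =>
      val buf = partial1.getNewAggregationBuffer
      partition.foreach(x => partial1.iterate(buf, x))
      buf
    }

    // Reduce side: merge() consumes terminatePartial() output, as in
    // partial2ModeEvaluator.merge(buffer, partial1Mode.evaluator.terminatePartial(input)).
    val merged = partial2.getNewAggregationBuffer
    mapBuffers.foreach(buf => partial2.merge(merged, partial1.terminatePartial(buf)))

    // Final step: terminate() on the merged buffer, as in eval().
    println(finalEv.terminate(merged)) // 3.0, the average of 1..5
  }
}
```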
```diff
@@ -439,9 +437,10 @@ private[hive] case class HiveUDAFFunction(
 
   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
   private class AggregationBufferSerDe {
-    private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
+    private val partialResultUnwrapper = unwrapperFor(partial1Mode.objectInspector)
 
-    private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
+    private val partialResultWrapper =
+      wrapperFor(partial1Mode.objectInspector, partialResultDataType)
 
     private val projection = UnsafeProjection.create(Array(partialResultDataType))
 
@@ -451,7 +450,7 @@ private[hive] case class HiveUDAFFunction(
      // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
      // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
      // Then we can unwrap it to a Spark SQL value.
-      mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
+      mutableRow.update(0, partialResultUnwrapper(partial1Mode.evaluator.terminatePartial(buffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
```
New test in `HiveUDFSuite`:

```diff
@@ -638,6 +638,21 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
         Row(3) :: Row(3) :: Nil)
     }
   }
+
+  test("constant argument expecting Hive UDF") {
+    val testData = spark.range(10).toDF()
+    withTempView("inputTable") {
+      testData.createOrReplaceTempView("inputTable")
+      withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) {
+        val numFunc = spark.catalog.listFunctions().count()
+        sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" +
+          s"${classOf[GenericUDAFPercentileApprox].getName}'")
+        checkAnswer(
+          sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable"),
+          Seq(Row(4.0)))
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
```
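For context, the code path this test exercises can also be reached from a plain Spark application. The following standalone sketch is not part of the PR; it assumes a Hive-enabled Spark build on the classpath, and the session settings and function name are arbitrary choices. Unlike the test, which creates a permanent function via `withUserDefinedFunction`, this sketch registers a temporary one.

```scala
import org.apache.spark.sql.SparkSession

object ConstantArgUdafExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session; requires Spark built with Hive support.
    val spark = SparkSession.builder()
      .appName("constant-argument-udaf-example")
      .master("local[2]")
      .enableHiveSupport()
      .getOrCreate()

    spark.range(10).createOrReplaceTempView("inputTable")

    // Register the same Hive UDAF the test uses. Its second argument (the
    // percentile, 0.5 here) is expected to be a constant, which is the case
    // this PR is about.
    spark.sql(
      "CREATE TEMPORARY FUNCTION testGenericUDAFPercentileApprox AS " +
        "'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox'")

    // The test above expects this query to return 4.0 for spark.range(10).
    spark.sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable").show()

    spark.stop()
  }
}
```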
(A review comment on this diff quotes the removed `partial1ModeEvaluator`.)