-
Notifications
You must be signed in to change notification settings - Fork 29.3k
[SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects #28983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects #28983
Changes from 18 commits
1a501d9
73299e8
20012b3
a8dd23d
bc2d880
399cbab
2139c14
2092b3a
ac22ccf
1351b76
e923d2f
a4858d5
d3c5d4d
814956c
aca7b51
e679c01
c632437
ee96cc0
622ac1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression | |
| import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection} | ||
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
|
|
@@ -458,7 +460,8 @@ case class ScalaUDAF( | |
| case class ScalaAggregator[IN, BUF, OUT]( | ||
| children: Seq[Expression], | ||
| agg: Aggregator[IN, BUF, OUT], | ||
| inputEncoderNR: ExpressionEncoder[IN], | ||
| inputEncoder: ExpressionEncoder[IN], | ||
| bufferEncoder: ExpressionEncoder[BUF], | ||
| nullable: Boolean = true, | ||
| isDeterministic: Boolean = true, | ||
| mutableAggBufferOffset: Int = 0, | ||
|
|
@@ -469,17 +472,16 @@ case class ScalaAggregator[IN, BUF, OUT]( | |
| with ImplicitCastInputTypes | ||
| with Logging { | ||
|
|
||
| private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer() | ||
| private[this] lazy val bufferEncoder = | ||
| agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind() | ||
| // input and buffer encoders are resolved by ResolveEncodersInScalaAgg | ||
| private[this] lazy val inputDeserializer = inputEncoder.createDeserializer() | ||
| private[this] lazy val bufferSerializer = bufferEncoder.createSerializer() | ||
| private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer() | ||
| private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]] | ||
| private[this] lazy val outputSerializer = outputEncoder.createSerializer() | ||
|
|
||
| def dataType: DataType = outputEncoder.objSerializer.dataType | ||
|
|
||
| def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType) | ||
| def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType) | ||
|
|
||
| override lazy val deterministic: Boolean = isDeterministic | ||
|
|
||
|
|
@@ -517,3 +519,18 @@ case class ScalaAggregator[IN, BUF, OUT]( | |
|
|
||
| override def nodeName: String = agg.getClass.getSimpleName | ||
| } | ||
|
|
||
| /** | ||
| * An extension rule to resolve encoder expressions from a [[ScalaAggregator]] | ||
| */ | ||
| object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] { | ||
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { | ||
| case p if !p.resolved => p | ||
| case p => p.transformExpressionsUp { | ||
| case agg: ScalaAggregator[_, _, _] => | ||
| agg.copy( | ||
| inputEncoder = agg.inputEncoder.resolveAndBind(), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A followup we can do is to resolve and bind using the actual input data types, so that we can do casting or reorder fields.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would be nice. I tried this, but the way I did it wasn't having any effect.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cloud-fan what I had done earlier was: object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.resolved => p
case p => p.transformExpressionsUp {
case agg: ScalaAggregator[_, _, _] =>
val children = agg.children
require(children.length > 0, "Missing aggregator input")
val dataType: DataType = if (children.length == 1) children.head.dataType else {
StructType(children.map(_.dataType).zipWithIndex.map { case (dt, j) =>
StructField(s"_$j", dt, true)
})
}
val attrs = if (agg.inputEncoder.isSerializedAsStructForTopLevel) {
dataType.asInstanceOf[StructType].toAttributes
} else {
(new StructType().add("input", dataType)).toAttributes
}
agg.copy(
inputEncoder = agg.inputEncoder.resolveAndBind(attrs),
bufferEncoder = agg.bufferEncoder.resolveAndBind())
}
}
}This also passes unit tests, but it would still fail if I tried to give it |
||
| bufferEncoder = agg.bufferEncoder.resolveAndBind()) | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.