@@ -678,6 +678,13 @@ object MapObjects {
elementType: DataType,
elementNullable: Boolean = true,
customCollectionCls: Option[Class[_]] = None): MapObjects = {
// UnresolvedMapObjects does not serialize its 'function' field.
// If an array expression or array Encoder is not correctly resolved before
// serialization, this exception condition may occur.
require(function != null,
"MapObjects applied with a null function. " +
"Likely cause is failure to resolve an array expression or encoder. " +
"(See UnresolvedMapObjects)")
val loopVar = LambdaVariable("MapObject", elementType, elementNullable)
MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
}
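
As context (not part of the diff): a minimal sketch of the fail-fast behavior the new require adds. Constructing MapObjects with a null function now raises a descriptive IllegalArgumentException instead of a NullPointerException deep inside codegen. The Literal input here is illustrative only.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.types.{ArrayType, DoubleType}

// Hypothetical input; any resolved expression of ArrayType would do.
val input = Literal.create(Seq(1.0, 2.0), ArrayType(DoubleType))

// Passing a null lambda now fails fast with the message added above:
// java.lang.IllegalArgumentException: requirement failed:
// MapObjects applied with a null function. ...
MapObjects(null, input, DoubleType)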
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

@@ -458,7 +460,8 @@ case class ScalaUDAF(
case class ScalaAggregator[IN, BUF, OUT](
children: Seq[Expression],
agg: Aggregator[IN, BUF, OUT],
inputEncoderNR: ExpressionEncoder[IN],
inputEncoder: ExpressionEncoder[IN],
bufferEncoder: ExpressionEncoder[BUF],
nullable: Boolean = true,
isDeterministic: Boolean = true,
mutableAggBufferOffset: Int = 0,
@@ -469,17 +472,16 @@ case class ScalaAggregator[IN, BUF, OUT](
with ImplicitCastInputTypes
with Logging {

private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer()
private[this] lazy val bufferEncoder =
agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
// input and buffer encoders are resolved by ResolveEncodersInScalaAgg
private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
private[this] lazy val outputSerializer = outputEncoder.createSerializer()

def dataType: DataType = outputEncoder.objSerializer.dataType

def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType)
def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)

override lazy val deterministic: Boolean = isDeterministic

@@ -517,3 +519,18 @@ case class ScalaAggregator[IN, BUF, OUT](

override def nodeName: String = agg.getClass.getSimpleName
}

/**
* An extension rule to resolve encoder expressions from a [[ScalaAggregator]]
*/
object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.resolved => p
case p => p.transformExpressionsUp {
case agg: ScalaAggregator[_, _, _] =>
agg.copy(
inputEncoder = agg.inputEncoder.resolveAndBind(),
Contributor: A followup we can do is to resolve and bind using the actual input data types, so that we can do casting or reorder fields.

Contributor Author: That would be nice. I tried this, but the way I did it wasn't having any effect.

Contributor Author: @cloud-fan what I had done earlier was:

object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case p if !p.resolved => p
    case p => p.transformExpressionsUp {
      case agg: ScalaAggregator[_, _, _] =>
        val children = agg.children
        require(children.length > 0, "Missing aggregator input")
        val dataType: DataType = if (children.length == 1) children.head.dataType else {
          StructType(children.map(_.dataType).zipWithIndex.map { case (dt, j) =>
            StructField(s"_$j", dt, true)
          })
        }
        val attrs = if (agg.inputEncoder.isSerializedAsStructForTopLevel) {
          dataType.asInstanceOf[StructType].toAttributes
        } else {
          (new StructType().add("input", dataType)).toAttributes
        }
        agg.copy(
          inputEncoder = agg.inputEncoder.resolveAndBind(attrs),
          bufferEncoder = agg.bufferEncoder.resolveAndBind())
    }
  }
}

This also passes unit tests, but it would still fail if I tried to give it Float data, so it's not automatically casting.

bufferEncoder = agg.bufferEncoder.resolveAndBind())
}
}
}
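
As a rough illustration of what the rule does per aggregator (assumed API usage, not part of the diff): resolveAndBind() resolves an encoder's serializer/deserializer expressions against the encoder's own schema at analysis time, so the lazy createSerializer()/createDeserializer() calls in ScalaAggregator no longer depend on resolution happening later on the executor side.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Sketch: eagerly resolve an array encoder, as the rule does for each
// ScalaAggregator's input and buffer encoders.
val enc = ExpressionEncoder[Array[Double]]()
val resolved = enc.resolveAndBind()
val row = resolved.createSerializer()(Array(1.0, 2.0, 3.0))
val back = resolved.createDeserializer()(row) // Array(1.0, 2.0, 3.0)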
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
// This is also used by udf.register(...) when it detects a UserDefinedAggregator
def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
}

override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
import org.apache.spark.sql.execution.datasources._
@@ -175,6 +176,7 @@ abstract class BaseSessionStateBuilder(
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
new FallBackFileSourceV2(session) +:
ResolveEncodersInScalaAgg +:
new ResolveSessionCatalog(
catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
customResolutionRules
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
import org.apache.spark.sql.execution.datasources._
@@ -76,6 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState])
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
new FallBackFileSourceV2(session) +:
ResolveEncodersInScalaAgg +:
new ResolveSessionCatalog(
catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
customResolutionRules
@@ -119,6 +119,27 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
}

object ArrayDataAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] {
def zero: Array[Double] = Array(0.0, 0.0, 0.0)
def reduce(s: Array[Double], array: Array[Double]): Array[Double] = {
require(s.length == array.length)
for ( j <- 0 until s.length ) {
s(j) += array(j)
}
s
}
def merge(s1: Array[Double], s2: Array[Double]): Array[Double] = {
require(s1.length == s2.length)
for ( j <- 0 until s1.length ) {
s1(j) += s2(j)
}
s1
}
def finish(s: Array[Double]): Array[Double] = s
def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
}

abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._

@@ -156,20 +177,11 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
(3, null, null)).toDF("key", "value1", "value2")
data2.write.saveAsTable("agg2")

val data3 = Seq[(Seq[Integer], Integer, Integer)](
(Seq[Integer](1, 1), 10, -10),
(Seq[Integer](null), -60, 60),
(Seq[Integer](1, 1), 30, -30),
(Seq[Integer](1), 30, 30),
(Seq[Integer](2), 1, 1),
(null, -10, 10),
(Seq[Integer](2, 3), -1, null),
(Seq[Integer](2, 3), 1, 1),
(Seq[Integer](2, 3, 4), null, 1),
(Seq[Integer](null), 100, -10),
(Seq[Integer](3), null, 3),
(null, null, null),
(Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
val data3 = Seq[(Seq[Double], Int)](
(Seq(1.0, 2.0, 3.0), 0),
(Seq(4.0, 5.0, 6.0), 0),
(Seq(7.0, 8.0, 9.0), 0)
).toDF("data", "dummy")
data3.write.saveAsTable("agg3")

val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
@@ -184,6 +196,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
spark.udf.register("longProductSum", udaf(LongProductSumAgg))
spark.udf.register("arraysum", udaf(ArrayDataAgg))
}

override def afterAll(): Unit = {
@@ -354,6 +367,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
}

test("SPARK-32159: array encoders should be resolved in analyzer") {
checkAnswer(
spark.sql("SELECT arraysum(data) FROM agg3"),
Row(Seq(12.0, 15.0, 18.0)) :: Nil)
}
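
For comparison, a sketch (not part of the diff) of the same aggregation through the DataFrame API, assuming the agg3 table and ArrayDataAgg defined above:

import org.apache.spark.sql.functions.udaf

// DataFrame-API equivalent of the SQL test above.
val result = spark.table("agg3").agg(udaf(ArrayDataAgg).apply($"data"))
checkAnswer(result, Row(Seq(12.0, 15.0, 18.0)) :: Nil)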

test("verify aggregator ser/de behavior") {
val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
val agg = udaf(CountSerDeAgg)