From 1a501d9a62cd7b22b560d2c255afa08d449da1d1 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Thu, 2 Jul 2020 14:00:59 -0700
Subject: [PATCH 01/19] fix SPARK-32159 - intercept null function in MapObjects

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +++
 .../spark/sql/catalyst/expressions/objects/objects.scala  | 5 ++++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d08a6382f738b..cd909eb649de4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3066,9 +3066,12 @@ class Analyzer(
       inputAttributes
     }
 
+    // println(s"\n\ndeserializer= $deserializer\n")
+    // println(s"\n\ninputs= $inputs\n")
     validateTopLevelTupleFields(deserializer, inputs)
     val resolved = resolveExpressionBottomUp(
       deserializer, LocalRelation(inputs), throws = true)
+    // println(s"\n\nresolved= $resolved\n")
     val result = resolved transformDown {
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d5de95c65e49e..da3df93f5d91f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -679,7 +679,10 @@ object MapObjects {
       elementNullable: Boolean = true,
       customCollectionCls: Option[Class[_]] = None): MapObjects = {
     val loopVar = LambdaVariable("MapObject", elementType, elementNullable)
-    MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
+    // println(s"\n\n $function, $inputData, $elementType, $elementNullable, $customCollectionCls \n")
+    val fOfLV = if (function == null) loopVar else function(loopVar)
+    // MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
+    MapObjects(loopVar, fOfLV, inputData, customCollectionCls)
   }
 }

From 73299e8a7aedd24dafd6964c54df117e3019cbe6 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Sun, 5 Jul 2020 15:01:15 -0700
Subject: [PATCH 02/19] revert println

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cd909eb649de4..d08a6382f738b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3066,12 +3066,9 @@ class Analyzer(
       inputAttributes
     }
 
-    // println(s"\n\ndeserializer= $deserializer\n")
-    // println(s"\n\ninputs= $inputs\n")
     validateTopLevelTupleFields(deserializer, inputs)
     val resolved = resolveExpressionBottomUp(
       deserializer, LocalRelation(inputs), throws = true)
-    // println(s"\n\nresolved= $resolved\n")
     val result = resolved transformDown {
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {

From 20012b3f9c57509057fb337af884b665bcb6f9bc Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Sun, 5 Jul 2020 15:01:46 -0700
Subject: [PATCH 03/19] add informative exception for null function

---
 .../spark/sql/catalyst/expressions/objects/objects.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index da3df93f5d91f..46e30784ffcf2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -678,11 +678,11 @@ object MapObjects {
       elementType: DataType,
       elementNullable: Boolean = true,
       customCollectionCls: Option[Class[_]] = None): MapObjects = {
+    if (function == null) {
+      throw new UnsupportedOperationException("Cannot instantiate MapObjects with null function")
+    }
     val loopVar = LambdaVariable("MapObject", elementType, elementNullable)
-    // println(s"\n\n $function, $inputData, $elementType, $elementNullable, $customCollectionCls \n")
-    val fOfLV = if (function == null) loopVar else function(loopVar)
-    // MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
-    MapObjects(loopVar, fOfLV, inputData, customCollectionCls)
+    MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
   }
 }
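Note: PATCH 03 replaces the PATCH 01 workaround (silently substituting the loop variable when `function` is null) with an explicit error. The null arises because UnresolvedMapObjects does not serialize its `function` closure (the comment added in PATCH 18 below spells this out), so a copy of the expression resolved after a serialization round trip sees `function == null`. A minimal, self-contained sketch of that mechanism — the class and field names here are illustrative stand-ins, not Spark source:

    import java.io._

    // Stand-in for an expression carrying a closure excluded from serialization.
    case class Unresolved(@transient function: Int => Int, tag: String)

    object TransientLossDemo {
      // Serialize and deserialize a value, as happens when expressions are shipped.
      def roundTrip[T <: Serializable](t: T): T = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(t)
        out.close()
        new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
          .readObject().asInstanceOf[T]
      }

      def main(args: Array[String]): Unit = {
        val copied = roundTrip(Unresolved(_ + 1, "increment"))
        assert(copied.function == null) // the @transient closure did not survive
      }
    }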
From a8dd23d912670e727e18738d35d50ac8ab9d6981 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Sun, 5 Jul 2020 15:03:12 -0700
Subject: [PATCH 04/19] move resolution of ScalaAggregator inputEncoder to new
 rule ResolveEncodersInScalaAgg

---
 .../spark/sql/execution/aggregate/udaf.scala | 22 ++++++++++++++++---
 .../internal/BaseSessionStateBuilder.scala   |  2 ++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 544b90a736071..a5a6a0ad9ef65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types._
 
+
 /**
  * A helper trait used to create specialized setter and getter for types supported by
  * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer.
@@ -458,7 +461,7 @@ case class ScalaUDAF(
 case class ScalaAggregator[IN, BUF, OUT](
     children: Seq[Expression],
     agg: Aggregator[IN, BUF, OUT],
-    inputEncoderNR: ExpressionEncoder[IN],
+    inputEncoder: ExpressionEncoder[IN],
     nullable: Boolean = true,
     isDeterministic: Boolean = true,
     mutableAggBufferOffset: Int = 0,
@@ -469,17 +472,20 @@ case class ScalaAggregator[IN, BUF, OUT](
   with ImplicitCastInputTypes
   with Logging {
 
-  private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer()
+  // input encoder is resolved by ResolveEncodersInScalaAgg
+  private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
+
   private[this] lazy val bufferEncoder =
     agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
   private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
   private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
+
   private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
   private[this] lazy val outputSerializer = outputEncoder.createSerializer()
 
   def dataType: DataType = outputEncoder.objSerializer.dataType
 
-  def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType)
+  def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)
 
   override lazy val deterministic: Boolean = isDeterministic
 
@@ -517,3 +523,13 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   override def nodeName: String = agg.getClass.getSimpleName
 }
+
+class ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case p if !p.resolved => p
+    case p => p.transformExpressionsUp {
+      case agg: ScalaAggregator[_, _, _] =>
+        agg.copy(inputEncoder = agg.inputEncoder.resolveAndBind())
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 3bbdbb002cca8..8405bbee3a264 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
@@ -175,6 +176,7 @@ abstract class BaseSessionStateBuilder(
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
+      new ResolveEncodersInScalaAgg() +:
       new ResolveSessionCatalog(
         catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
       customResolutionRules

From bc2d880c66844e2eb4be28bffc33b81b184a82e2 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Sun, 5 Jul 2020 15:04:16 -0700
Subject: [PATCH 05/19] add ResolveEncodersInScalaAgg rule to TestHive

---
 .../org/apache/spark/sql/hive/test/TestHive.scala | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index a3e2444cae887..1515d0077bc62 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -35,7 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SparkSessionExtensions, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
@@ -46,9 +46,19 @@ import org.apache.spark.sql.execution.command.CacheTableCommand
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf}
-import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH}
+import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, SPARK_SESSION_EXTENSIONS, WAREHOUSE_PATH}
 import org.apache.spark.util.{ShutdownHookManager, Utils}
 
+class TestHiveExtensions extends (SparkSessionExtensions => Unit) {
+  import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
+
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectResolutionRule { session =>
+      new ResolveEncodersInScalaAgg() // SPARK-32159
+    }
+  }
+}
+
 // SPARK-3729: Test key required to check for initialization errors with config.
 object TestHive
   extends TestHiveContext(
@@ -61,6 +71,7 @@ object TestHive
       .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
         "org.apache.spark.sql.hive.execution.PairSerDe")
       .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
+      .set(SPARK_SESSION_EXTENSIONS.key, classOf[TestHiveExtensions].getCanonicalName)
       // SPARK-8910
       .set(UI_ENABLED, false)
       .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
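Note: the strategy of PATCH 04 is to stop resolving the input encoder inside a lazy val — which could fire after the ScalaAggregator expression had been serialized, once the encoder's unresolved parts have lost their closures — and instead resolve it in a dedicated analyzer rule that runs on the driver before any serialization. PATCH 05 then wires the rule into TestHive through the session-extensions mechanism. For orientation, a hedged sketch of the path an Aggregator takes to become a ScalaAggregator expression, using the public `udaf` helper from Spark 3; `SumAgg` and the session setup are illustrative:

    import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf

    // Illustrative aggregator; any Aggregator[IN, BUF, OUT] follows the same path.
    object SumAgg extends Aggregator[Double, Double, Double] {
      def zero: Double = 0.0
      def reduce(b: Double, a: Double): Double = b + a
      def merge(b1: Double, b2: Double): Double = b1 + b2
      def finish(b: Double): Double = b
      def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
      def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }

    object Usage {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        // udaf(...) wraps the Aggregator in a UserDefinedAggregator; calling the
        // registered function builds the ScalaAggregator expression, which the
        // new ResolveEncodersInScalaAgg rule resolves during analysis.
        spark.udf.register("sumagg", udaf(SumAgg))
        spark.sql("SELECT sumagg(id * 1.0) FROM range(10)").show()
        spark.stop()
      }
    }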
From 399cbab9ce18fe8f0879bc86c406fde7001dda65 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Sun, 5 Jul 2020 16:04:23 -0700
Subject: [PATCH 06/19] add unit test for array input type

---
 .../sql/hive/execution/UDAQuerySuite.scala | 36 +++++++++++--------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index e6856a58b0ea9..6baccdd2f8709 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -119,6 +119,15 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
   def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
 }
 
+object ArrayDataAgg extends Aggregator[Array[Double], Double, Double] {
+  def zero: Double = 0.0
+  def reduce(s: Double, array: Array[Double]): Double = s + array.sum
+  def merge(s1: Double, s2: Double): Double = s1 + s2
+  def finish(s: Double): Double = s
+  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
+  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
+}
+
 abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
@@ -156,20 +165,11 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       (3, null, null)).toDF("key", "value1", "value2")
     data2.write.saveAsTable("agg2")
 
-    val data3 = Seq[(Seq[Integer], Integer, Integer)](
-      (Seq[Integer](1, 1), 10, -10),
-      (Seq[Integer](null), -60, 60),
-      (Seq[Integer](1, 1), 30, -30),
-      (Seq[Integer](1), 30, 30),
-      (Seq[Integer](2), 1, 1),
-      (null, -10, 10),
-      (Seq[Integer](2, 3), -1, null),
-      (Seq[Integer](2, 3), 1, 1),
-      (Seq[Integer](2, 3, 4), null, 1),
-      (Seq[Integer](null), 100, -10),
-      (Seq[Integer](3), null, 3),
-      (null, null, null),
-      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    val data3 = Seq[(Seq[Double], Int)](
+      (Seq(1.0), 0),
+      (Seq(2.0, 3.0), 0),
+      (Seq(4.0, 5.0, 6.0), 0)
+    ).toDF("data", "dummy")
     data3.write.saveAsTable("agg3")
 
     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
@@ -184,6 +184,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
+    spark.udf.register("arraysum", udaf(ArrayDataAgg))
   }
 
   override def afterAll(): Unit = {
@@ -354,6 +355,13 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
+  // SPARK-32159
+  test("array input types") {
+    checkAnswer(
+      spark.sql("SELECT arraysum(data) FROM agg3"),
+      Row(21.0) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)

From 2139c14ef2ffc0c85ce49bc796d918de0b17f9e5 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Mon, 6 Jul 2020 07:38:07 -0700
Subject: [PATCH 07/19] add ResolveEncodersInScalaAgg rule to
 HiveSessionStateBuilder

---
 .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 64726755237a6..e16ef2596d412 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
@@ -76,6 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
+      new ResolveEncodersInScalaAgg() +:
       new ResolveSessionCatalog(
         catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
       customResolutionRules

From 2092b3a1965d624c772dd48d174d8bd0b59c0ea2 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Mon, 6 Jul 2020 07:38:51 -0700
Subject: [PATCH 08/19] revert session extension in TestHive

---
 .../org/apache/spark/sql/hive/test/TestHive.scala | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 1515d0077bc62..a3e2444cae887 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -35,7 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SparkSessionExtensions, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
@@ -46,19 +46,9 @@ import org.apache.spark.sql.execution.command.CacheTableCommand
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf}
-import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, SPARK_SESSION_EXTENSIONS, WAREHOUSE_PATH}
+import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH}
 import org.apache.spark.util.{ShutdownHookManager, Utils}
 
-class TestHiveExtensions extends (SparkSessionExtensions => Unit) {
-  import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
-
-  def apply(e: SparkSessionExtensions): Unit = {
-    e.injectResolutionRule { session =>
-      new ResolveEncodersInScalaAgg() // SPARK-32159
-    }
-  }
-}
-
 // SPARK-3729: Test key required to check for initialization errors with config.
 object TestHive
   extends TestHiveContext(
@@ -71,7 +61,6 @@ object TestHive
       .set(HiveUtils.HIVE_METASTORE_BARRIER_PREFIXES.key,
         "org.apache.spark.sql.hive.execution.PairSerDe")
      .set(WAREHOUSE_PATH.key, TestHiveContext.makeWarehouseDir().toURI.getPath)
-      .set(SPARK_SESSION_EXTENSIONS.key, classOf[TestHiveExtensions].getCanonicalName)
       // SPARK-8910
       .set(UI_ENABLED, false)
       .set(config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)

From ac22ccfe2290c267b0ae22bedcfa4374ef5baa8c Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Mon, 6 Jul 2020 07:50:43 -0700
Subject: [PATCH 09/19] revert spacing

---
 .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index a5a6a0ad9ef65..94dddd015c4ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -474,12 +474,10 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   // input encoder is resolved by ResolveEncodersInScalaAgg
   private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
-
   private[this] lazy val bufferEncoder =
     agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
   private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
   private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
-
   private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
   private[this] lazy val outputSerializer = outputEncoder.createSerializer()
 

From 1351b7605d4761d8b64c87857cc01784d7360bbc Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 05:39:42 -0700
Subject: [PATCH 10/19] revert space

---
 .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 94dddd015c4ad..c678396f9d19c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types._
 
-
 /**
  * A helper trait used to create specialized setter and getter for types supported by
 * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer.
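Note: PATCHES 06-08 are a course correction. The rule first reached the Hive test session through the TestHive extension added in PATCH 05; once PATCH 07 registers it directly in HiveSessionStateBuilder (the Hive counterpart of BaseSessionStateBuilder), that extension is redundant and PATCH 08 reverts it. Below is a hedged reproducer for the bug the new test guards against, assuming `ArrayDataAgg` from the suite above is on the classpath; the view name and the app scaffolding are illustrative:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.udaf

    object Spark32159Repro {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        import spark.implicits._

        Seq(Tuple1(Seq(1.0, 2.0, 3.0)), Tuple1(Seq(4.0, 5.0, 6.0)))
          .toDF("data").createOrReplaceTempView("arrs")

        spark.udf.register("arraysum", udaf(ArrayDataAgg))
        // Without ResolveEncodersInScalaAgg, resolution of the Array[Double]
        // input encoder was deferred past serialization and failed inside
        // MapObjects (the null-function error surfaced by PATCH 03); with the
        // rule in place the query aggregates normally.
        spark.sql("SELECT arraysum(data) FROM arrs").show()
        spark.stop()
      }
    }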
From e923d2fab9e067c118a669faac5b6b3fdff2dfd0 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 05:52:44 -0700
Subject: [PATCH 11/19] move jira into test name

---
 .../org/apache/spark/sql/hive/execution/UDAQuerySuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 6baccdd2f8709..9f6c6ca12972e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -355,8 +355,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
-  // SPARK-32159
-  test("array input types") {
+  test("SPARK-32159: array input types") {
     checkAnswer(
       spark.sql("SELECT arraysum(data) FROM agg3"),
       Row(21.0) :: Nil)

From a4858d51ddc09958b9b42ec747d341be3bcec825 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 05:53:12 -0700
Subject: [PATCH 12/19] add comment to new rule

---
 .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index c678396f9d19c..0d290dc132b24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -521,6 +521,9 @@ case class ScalaAggregator[IN, BUF, OUT](
   override def nodeName: String = agg.getClass.getSimpleName
 }
 
+/**
+ * An extension rule to resolve ScalaAggregator input types from the input encoder
+ */
 class ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
     case p if !p.resolved => p

From d3c5d4d14480f61a769d37b95d8ad235ba04ac36 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 06:05:13 -0700
Subject: [PATCH 13/19] make ResolveEncodersInScalaAgg an object

---
 .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala   | 2 +-
 .../org/apache/spark/sql/internal/BaseSessionStateBuilder.scala | 2 +-
 .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala     | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 0d290dc132b24..979840d2a815a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -524,7 +524,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 /**
  * An extension rule to resolve ScalaAggregator input types from the input encoder
  */
-class ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
+object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
     case p if !p.resolved => p
     case p => p.transformExpressionsUp {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 8405bbee3a264..4ae12f8716752 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -176,7 +176,7 @@ abstract class BaseSessionStateBuilder(
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
-      new ResolveEncodersInScalaAgg() +:
+      ResolveEncodersInScalaAgg +:
       new ResolveSessionCatalog(
         catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
       customResolutionRules
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index e16ef2596d412..e25610757a69b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -77,7 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
-      new ResolveEncodersInScalaAgg() +:
+      ResolveEncodersInScalaAgg +:
       new ResolveSessionCatalog(
         catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
       customResolutionRules
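Note: PATCH 13's class-to-object change reflects that the rule is stateless — it closes over nothing per session — so one shared instance can be prepended to every session's resolution rules, matching how other parameterless Catalyst rules are registered. Builds that cannot modify the session state builders could inject the same singleton through the extensions API that PATCH 05 briefly used; a hedged sketch, where the extensions class name is illustrative and injectResolutionRule is the method PATCH 05 calls:

    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg

    // Registers the singleton rule in any session built with these extensions.
    class AggEncoderExtensions extends (SparkSessionExtensions => Unit) {
      def apply(e: SparkSessionExtensions): Unit =
        e.injectResolutionRule(_ => ResolveEncodersInScalaAgg)
    }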
From 814956cd603393317534910b591acd1875bf1b70 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 13:22:16 -0700
Subject: [PATCH 14/19] add bufferEncoder as a parameter to ScalaAggregator to
 make resolution easier

---
 .../apache/spark/sql/execution/aggregate/udaf.scala | 11 ++++++-----
 .../spark/sql/expressions/UserDefinedFunction.scala |  3 ++-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 979840d2a815a..26d5394bd0abc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -461,6 +461,7 @@ case class ScalaAggregator[IN, BUF, OUT](
     children: Seq[Expression],
     agg: Aggregator[IN, BUF, OUT],
     inputEncoder: ExpressionEncoder[IN],
+    bufferEncoder: ExpressionEncoder[BUF],
     nullable: Boolean = true,
     isDeterministic: Boolean = true,
     mutableAggBufferOffset: Int = 0,
@@ -471,10 +472,8 @@ case class ScalaAggregator[IN, BUF, OUT](
   with ImplicitCastInputTypes
   with Logging {
 
-  // input encoder is resolved by ResolveEncodersInScalaAgg
+  // input and buffer encoders are resolved by ResolveEncodersInScalaAgg
   private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
-  private[this] lazy val bufferEncoder =
-    agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
   private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
   private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
   private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
@@ -522,14 +521,16 @@ case class ScalaAggregator[IN, BUF, OUT](
 }
 
 /**
- * An extension rule to resolve ScalaAggregator input types from the input encoder
+ * An extension rule to resolve encoder expressions from a ScalaAggregator
 */
 object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
     case p if !p.resolved => p
     case p => p.transformExpressionsUp {
       case agg: ScalaAggregator[_, _, _] =>
-        agg.copy(inputEncoder = agg.inputEncoder.resolveAndBind())
+        agg.copy(
+          inputEncoder = agg.inputEncoder.resolveAndBind(),
+          bufferEncoder = agg.bufferEncoder.resolveAndBind())
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 2ef6e3d291cef..6a20a46756f85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
   // This is also used by udf.register(...) when it detects a UserDefinedAggregator
   def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
     val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
-    ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
+    val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
+    ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
   }
 
   override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {

From aca7b51727045e181740ee01c0f52cf83a2052b8 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 13:23:17 -0700
Subject: [PATCH 15/19] test Array type with input, buffer and output encoders

---
 .../sql/hive/execution/UDAQuerySuite.scala | 36 ++++++++++++------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 9f6c6ca12972e..2bd4db6f59b49 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -119,13 +119,25 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
   def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
 }
 
-object ArrayDataAgg extends Aggregator[Array[Double], Double, Double] {
-  def zero: Double = 0.0
-  def reduce(s: Double, array: Array[Double]): Double = s + array.sum
-  def merge(s1: Double, s2: Double): Double = s1 + s2
-  def finish(s: Double): Double = s
-  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
-  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
+object ArrayDataAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] {
+  def zero: Array[Double] = Array(0.0, 0.0, 0.0)
+  def reduce(s: Array[Double], array: Array[Double]): Array[Double] = {
+    require(s.length == array.length)
+    for ( j <- 0 until s.length ) {
+      s(j) += array(j)
+    }
+    s
+  }
+  def merge(s1: Array[Double], s2: Array[Double]): Array[Double] = {
+    require(s1.length == s2.length)
+    for ( j <- 0 until s1.length ) {
+      s1(j) += s2(j)
+    }
+    s1
+  }
+  def finish(s: Array[Double]): Array[Double] = s
+  def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
+  def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
 }
 
 abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
@@ -166,9 +178,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     data2.write.saveAsTable("agg2")
 
     val data3 = Seq[(Seq[Double], Int)](
-      (Seq(1.0), 0),
-      (Seq(2.0, 3.0), 0),
-      (Seq(4.0, 5.0, 6.0), 0)
+      (Seq(1.0, 2.0, 3.0), 0),
+      (Seq(4.0, 5.0, 6.0), 0),
+      (Seq(7.0, 8.0, 9.0), 0)
     ).toDF("data", "dummy")
     data3.write.saveAsTable("agg3")
 
@@ -355,10 +367,10 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
-  test("SPARK-32159: array input types") {
+  test("SPARK-32159: array encoder types") {
     checkAnswer(
       spark.sql("SELECT arraysum(data) FROM agg3"),
-      Row(21.0) :: Nil)
+      Row(Seq(12.0, 15.0, 18.0)) :: Nil)
   }
 
   test("verify aggregator ser/de behavior") {

From e679c013ddd7ed9f5186b4df57a33486b13cffd3 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 17:01:12 -0700
Subject: [PATCH 16/19] more detail on array test name

---
 .../org/apache/spark/sql/hive/execution/UDAQuerySuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 2bd4db6f59b49..1f1a5568b0201 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -367,7 +367,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
-  test("SPARK-32159: array encoder types") {
+  test("SPARK-32159: array encoders should be resolved in analyzer") {
     checkAnswer(
       spark.sql("SELECT arraysum(data) FROM agg3"),
       Row(Seq(12.0, 15.0, 18.0)) :: Nil)

From c6324375a40cfd916836758dafe1b978fbcad5d6 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 17:02:29 -0700
Subject: [PATCH 17/19] add type link to scaladoc

---
 .../scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 26d5394bd0abc..44bc9c2e3a9d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -521,7 +521,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 }
 
 /**
- * An extension rule to resolve encoder expressions from a ScalaAggregator
+ * An extension rule to resolve encoder expressions from a [[ScalaAggregator]]
 */
 object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
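Note: PATCH 14's choice to promote bufferEncoder from a lazy val to a constructor parameter is what lets the analyzer rule resolve it, since Catalyst rewrites expressions with case-class copy(), which can only substitute constructor fields. A toy model of the distinction (illustrative types, not Spark source):

    // A rule rewrites nodes via copy(), so only constructor parameters can be
    // swapped for resolved values; values derived in the body are out of reach.
    case class Node(inputEnc: String) {          // parameter: a rule can replace it
      lazy val bufferEnc: String = "unresolved"  // derived: copy() cannot touch it
    }

    object CopyDemo {
      def main(args: Array[String]): Unit = {
        val rewritten = Node("unresolved").copy(inputEnc = "resolved")
        assert(rewritten.inputEnc == "resolved")
        assert(rewritten.bufferEnc == "unresolved") // unchanged: not a parameter
      }
    }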
From ee96cc012c19fc4611d214d44f9957ad76e64ad2 Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Tue, 7 Jul 2020 17:23:44 -0700
Subject: [PATCH 18/19] improved exception message and corresponding comments

---
 .../sql/catalyst/expressions/objects/objects.scala | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 46e30784ffcf2..ab9e51be22381 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -678,9 +678,13 @@ object MapObjects {
       elementType: DataType,
       elementNullable: Boolean = true,
       customCollectionCls: Option[Class[_]] = None): MapObjects = {
-    if (function == null) {
-      throw new UnsupportedOperationException("Cannot instantiate MapObjects with null function")
-    }
+    // UnresolvedMapObjects does not serialize its 'function' field.
+    // If an an array expression or array Encoder is not correctly resolved before
+    // serialization, this exception condition may occur.
+    require(function != null,
+      "MapObjects applied with a null function. " +
+      "Likely cause is failure to resolve an array expression or encoder. " +
+      "(See UnresolvedMapObjects)")
     val loopVar = LambdaVariable("MapObject", elementType, elementNullable)
     MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
   }

From 622ac1c245e0918d9c99af2c0cb69671284ee7ac Mon Sep 17 00:00:00 2001
From: Erik Erlandson
Date: Wed, 8 Jul 2020 14:01:31 -0700
Subject: [PATCH 19/19] too much 'an'

---
 .../apache/spark/sql/catalyst/expressions/objects/objects.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ab9e51be22381..ab2f66b1a53e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -679,7 +679,7 @@ object MapObjects {
       elementNullable: Boolean = true,
       customCollectionCls: Option[Class[_]] = None): MapObjects = {
     // UnresolvedMapObjects does not serialize its 'function' field.
-    // If an an array expression or array Encoder is not correctly resolved before
+    // If an array expression or array Encoder is not correctly resolved before
     // serialization, this exception condition may occur.
     require(function != null,
       "MapObjects applied with a null function. " +
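Note: PATCH 18 swaps the bare UnsupportedOperationException from PATCH 03 for a require(...) whose message names the likely root cause. One behavioral detail worth knowing: Scala's require throws IllegalArgumentException and prefixes the message with "requirement failed: ". A hedged sketch of what a caller now observes:

    // Demonstrates the standard require(cond, msg) contract the new guard uses.
    object RequireDemo {
      def main(args: Array[String]): Unit = {
        try {
          require(false,
            "MapObjects applied with a null function. " +
            "Likely cause is failure to resolve an array expression or encoder. " +
            "(See UnresolvedMapObjects)")
        } catch {
          // Prints: requirement failed: MapObjects applied with a null function. ...
          case e: IllegalArgumentException => println(e.getMessage)
        }
      }
    }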