From e1b5deebe715479125c8878f0c90a55dc9ab3e85 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Jul 2018 03:42:04 +0000 Subject: [PATCH 01/16] Aggregator should be able to use Option of Product encoder. --- .../catalyst/encoders/ExpressionEncoder.scala | 11 +++- .../spark/sql/DatasetAggregatorSuite.scala | 51 +++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cbea3c017a26..1b357698d2ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -43,12 +43,19 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](): ExpressionEncoder[T] = { + // Constructs an encoder for top-level row. + def apply[T : TypeTag](): ExpressionEncoder[T] = apply(topLevel = true) + + /** + * @param topLevel whether the encoders to construct are for top-level row. + */ + def apply[T : TypeTag](topLevel: Boolean): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - if (ScalaReflection.optionOfProductType(tpe)) { + // For non top-level encodes, we allow using Option of Product type. + if (topLevel && ScalaReflection.optionOfProductType(tpe)) { throw new UnsupportedOperationException( "Cannot create encoder for Option of Product type, because Product type is represented " + "as a row, and the entire row can not be null in Spark SQL like normal databases. 
" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 538ea3c66c40..d31d6d345a7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { case class OptionBooleanData(name: String, isGood: Option[Boolean]) +case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)]) case class OptionBooleanAggregator(colName: String) extends Aggregator[Row, Option[Boolean], Option[Boolean]] { @@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String) def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() } +case class OptionBooleanIntAggregator(colName: String) + extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] { + + override def zero: Option[(Boolean, Int)] = None + + override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[(Boolean, Int)] + } else { + val nestedRow = row.getStruct(index) + Some((nestedRow.getBoolean(0), nestedRow.getInt(1))) + } + merge(buffer, value) + } + + override def merge( + b1: Option[(Boolean, Int)], + b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = { + if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) { + val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0) + Some((true, newInt)) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction + + override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder + + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder(topLevel = false) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -393,4 +431,17 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } + + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { + val df = Seq( + OptionBooleanIntData("bob", Some((true, 1))), + OptionBooleanIntData("bob", Some((false, 2))), + OptionBooleanIntData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) + checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) + } } From 80506f4e98184ccd66dbaac14ec52d69c358020d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 13 Jul 2018 04:40:55 +0000 Subject: [PATCH 02/16] Enable top-level Option of Product encoders. 
--- .../spark/sql/catalyst/ScalaReflection.scala | 92 +++++++++++-------- .../catalyst/encoders/ExpressionEncoder.scala | 14 +-- .../sql/catalyst/ScalaReflectionSuite.scala | 74 ++++++++++----- .../spark/sql/DatasetAggregatorSuite.scala | 1 + .../org/apache/spark/sql/DatasetSuite.scala | 27 ++++-- 5 files changed, 127 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4543bba8f6ed..3f2505d5d689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -135,12 +135,20 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def deserializerFor[T : TypeTag]: Expression = { + def deserializerFor[T : TypeTag](topLevel: Boolean): Expression = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - val expr = deserializerFor(tpe, None, walkedTypePath) - val Schema(_, nullable) = schemaFor(tpe) + val Schema(dataType, tpeNullable) = schemaFor(tpe) + val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && + definedByConstructorParams(tpe) + val (optTypePath, nullable) = if (isOptionOfProduct && topLevel) { + // Top-level Option of Product is encoded as single struct column at top-level row. + (Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) + } else { + (None, tpeNullable) + } + val expr = deserializerFor(tpe, optTypePath, walkedTypePath) if (nullable) { expr } else { @@ -148,6 +156,40 @@ object ScalaReflection extends ScalaReflection { } } + /** + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff + * and lost the required data type, which may lead to runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit runtime error and say that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. + * + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. + */ + def upCastToExpectedType( + expr: Expression, + expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. + case _ => UpCast(expr, expected, walkedTypePath) + } + + /** Returns the current path with a field at ordinal extracted. 
*/ + def addToPathOrdinal( + path: Option[Expression], + ordinal: Int, + dataType: DataType, + walkedTypePath: Seq[String]): Expression = { + val newPath = path + .map(p => GetStructField(p, ordinal)) + .getOrElse(GetColumnByOrdinal(ordinal, dataType)) + upCastToExpectedType(newPath, dataType, walkedTypePath) + } + private def deserializerFor( tpe: `Type`, path: Option[Expression], @@ -161,17 +203,6 @@ object ScalaReflection extends ScalaReflection { upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal( - ordinal: Int, - dataType: DataType, - walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(GetColumnByOrdinal(ordinal, dataType)) - upCastToExpectedType(newPath, dataType, walkedTypePath) - } - /** Returns the current path or `GetColumnByOrdinal`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType @@ -182,28 +213,6 @@ object ScalaReflection extends ScalaReflection { } } - /** - * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff - * and lost the required data type, which may lead to runtime error if the real type doesn't - * match the encoder's schema. - * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type - * is [a: int, b: long], then we will hit runtime error and say that we can't construct class - * `Data` with int and long, because we lost the information that `b` should be a string. - * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * only need to do this for leaf nodes. - */ - def upCastToExpectedType( - expr: Expression, - expected: DataType, - walkedTypePath: Seq[String]): Expression = expected match { - case _: StructType => expr - case _: ArrayType => expr - // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and - // it's not trivial to support by-name resolution for StructType inside MapType. - case _ => UpCast(expr, expected, walkedTypePath) - } - tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -389,7 +398,7 @@ object ScalaReflection extends ScalaReflection { val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - Some(addToPathOrdinal(i, dataType, newTypePath)), + Some(addToPathOrdinal(path, i, dataType, newTypePath)), newTypePath) } else { deserializerFor( @@ -431,11 +440,18 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + def serializerFor[T : TypeTag]( + inputObject: Expression, + topLevel: Boolean): CreateNamedStruct = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { + case i @ expressions.If(_, _, _: CreateNamedStruct) + if tpe.dealias <:< localTypeOf[Option[_]] && + definedByConstructorParams(tpe) && topLevel => + // We encode top-level Option of Product as a single struct column. 
+ CreateNamedStruct(expressions.Literal("value") :: i :: Nil) case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1b357698d2ec..861451012ff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,16 +54,6 @@ object ExpressionEncoder { val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe - // For non top-level encodes, we allow using Option of Product type. - if (topLevel && ScalaReflection.optionOfProductType(tpe)) { - throw new UnsupportedOperationException( - "Cannot create encoder for Option of Product type, because Product type is represented " + - "as a row, and the entire row can not be null in Spark SQL like normal databases. " + - "You can wrap your type with Tuple1 if you do want top level null Product objects, " + - "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " + - "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`") - } - val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) @@ -75,8 +65,8 @@ object ExpressionEncoder { // doesn't allow top-level row to be null, only its columns can be null. AssertNotNull(inputObject, Seq("top level Product input object")) } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T] + val serializer = ScalaReflection.serializerFor[T](nullSafeInput, topLevel) + val deserializer = ScalaReflection.deserializerFor[T](topLevel) val schema = serializer.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 353b8344658f..0cb188ef9bb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Expression, If, IsNull, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance, WrapOption} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -281,7 +281,7 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false)) + 0, ObjectType(list.getClass), nullable = false), topLevel = true) 
assert(serializer.children.size == 2) assert(serializer.children.head.isInstanceOf[Literal]) assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) @@ -291,57 +291,57 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]] + val listDeserializer = deserializerFor[List[Int]](topLevel = true) assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false)) + 0, ObjectType(classOf[Queue[Int]]), nullable = false), topLevel = true) assert(queueSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]] + val queueDeserializer = deserializerFor[Queue[Int]](topLevel = true) assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) + 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false), topLevel = true) assert(arrayBufferSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]](topLevel = true) assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false), topLevel = true) assert(mapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]] + val mapDeserializer = deserializerFor[Map[Int, Int]](topLevel = true) assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false), topLevel = true) assert(hashMapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]](topLevel = true) assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false), topLevel = true) assert(linkedHashMapSerializer.dataType.head.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]](topLevel = true) assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special 
characters") { val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) - val deserializer = deserializerFor[SpecialCharAsFieldData] + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false), topLevel = true) + val deserializer = deserializerFor[SpecialCharAsFieldData](topLevel = true) assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -353,8 +353,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int](topLevel = true).isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String](topLevel = true).isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -371,8 +371,38 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializer.isInstanceOf[NewInstance]) deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)](topLevel = true)) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)](topLevel = true)) == 1) + assert(numberOfCheckedArguments( + deserializerFor[(java.lang.Integer, java.lang.Integer)](topLevel = true)) == 0) + } + + test("SPARK-24762: serializer for Option of Product") { + val optionOfProduct = Some((1, "a")) + val topLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( + 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = true) + val nonTopLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( + 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = false) + + topLevelSerializer match { + case CreateNamedStruct(Seq(Literal(_, _), If(_, _, optEncoder))) => + assert(optEncoder.semanticEquals(nonTopLevelSerializer)) + case _ => + fail("top-level Option of Product should be encoded as single struct column.") + } + } + + test("SPARK-24762: deserializer for Option of Product") { + val topLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = true) + val nonTopLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = false) + .asInstanceOf[WrapOption] + + topLevelDeserializer match { + case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), optType) => + assert(n.cls == nonTopLevelDeserializer.child.asInstanceOf[NewInstance].cls) + assert(optType == nonTopLevelDeserializer.optType) + case _ => + fail("top-level Option of Product should be decoded from a single struct column.") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index d31d6d345a7a..33241671dbcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -437,6 +437,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { OptionBooleanIntData("bob", Some((true, 1))), 
OptionBooleanIntData("bob", Some((false, 2))), OptionBooleanIntData("bob", None)).toDF() + val group = df .groupBy("name") .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ce8db99d4e2f..843eb224dfb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1253,15 +1253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(dsString, arrayString) } - test("SPARK-18251: the type of Dataset can't be Option of Product type") { - checkDataset(Seq(Some(1), None).toDS(), Some(1), None) - - val e = intercept[UnsupportedOperationException] { - Seq(Some(1 -> "a"), None).toDS() - } - assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) - } - test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt // instead of Int for avoiding possible overflow. @@ -1498,6 +1489,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.where($"city".contains(new java.lang.Character('A'))), Seq(Row("Amsterdam"))) } + + test("SPARK-24762: Enable top-level Option of Product encoders") { + val data = Seq(Some((1, "a")), Some((2, "b")), None) + val ds = data.toDS() + + checkDataset( + ds, + data: _*) + + val schema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true) + )), nullable = true) + )) + + assert(ds.schema == schema) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From ed3d5cb697b10af2e2cf4c78ab521d4d0b2f3c9b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Aug 2018 04:26:28 +0000 Subject: [PATCH 03/16] Remove topLevel parameter. --- .../spark/sql/catalyst/ScalaReflection.scala | 10 ++- .../catalyst/encoders/ExpressionEncoder.scala | 12 +--- .../sql/catalyst/ScalaReflectionSuite.scala | 68 +++++++++--------- .../aggregate/TypedAggregateExpression.scala | 72 +++++++++++++++++-- .../spark/sql/DatasetAggregatorSuite.scala | 2 +- 5 files changed, 107 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 3f2505d5d689..ebd2d3bf0dc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -135,14 +135,14 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. 
*/ - def deserializerFor[T : TypeTag](topLevel: Boolean): Expression = cleanUpReflectionObjects { + def deserializerFor[T : TypeTag](): Expression = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil val Schema(dataType, tpeNullable) = schemaFor(tpe) val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) - val (optTypePath, nullable) = if (isOptionOfProduct && topLevel) { + val (optTypePath, nullable) = if (isOptionOfProduct) { // Top-level Option of Product is encoded as single struct column at top-level row. (Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) } else { @@ -441,15 +441,13 @@ object ScalaReflection extends ScalaReflection { * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ def serializerFor[T : TypeTag]( - inputObject: Expression, - topLevel: Boolean): CreateNamedStruct = cleanUpReflectionObjects { + inputObject: Expression): CreateNamedStruct = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { case i @ expressions.If(_, _, _: CreateNamedStruct) - if tpe.dealias <:< localTypeOf[Option[_]] && - definedByConstructorParams(tpe) && topLevel => + if tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) => // We encode top-level Option of Product as a single struct column. CreateNamedStruct(expressions.Literal("value") :: i :: Nil) case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 861451012ff7..a90137d0029d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -43,13 +43,7 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { - // Constructs an encoder for top-level row. - def apply[T : TypeTag](): ExpressionEncoder[T] = apply(topLevel = true) - - /** - * @param topLevel whether the encoders to construct are for top-level row. - */ - def apply[T : TypeTag](topLevel: Boolean): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe @@ -65,8 +59,8 @@ object ExpressionEncoder { // doesn't allow top-level row to be null, only its columns can be null. 
AssertNotNull(inputObject, Seq("top level Product input object")) } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput, topLevel) - val deserializer = ScalaReflection.deserializerFor[T](topLevel) + val serializer = ScalaReflection.serializerFor[T](nullSafeInput) + val deserializer = ScalaReflection.deserializerFor[T]() val schema = serializer.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 0cb188ef9bb7..da9f2f2d0929 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -281,7 +281,7 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false), topLevel = true) + 0, ObjectType(list.getClass), nullable = false)) assert(serializer.children.size == 2) assert(serializer.children.head.isInstanceOf[Literal]) assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) @@ -291,57 +291,57 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]](topLevel = true) + val listDeserializer = deserializerFor[List[Int]]() assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[Queue[Int]]), nullable = false)) assert(queueSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]](topLevel = true) + val queueDeserializer = deserializerFor[Queue[Int]]() assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) assert(arrayBufferSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]](topLevel = true) + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]() assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) assert(mapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]](topLevel = true) + val mapDeserializer = deserializerFor[Map[Int, Int]]() assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false), 
topLevel = true) + 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) assert(hashMapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]](topLevel = true) + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]() assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false), topLevel = true) + 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) assert(linkedHashMapSerializer.dataType.head.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]](topLevel = true) + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]() assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false), topLevel = true) - val deserializer = deserializerFor[SpecialCharAsFieldData](topLevel = true) + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val deserializer = deserializerFor[SpecialCharAsFieldData]() assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -353,8 +353,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int](topLevel = true).isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String](topLevel = true).isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int]().isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String]().isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -371,36 +371,34 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializer.isInstanceOf[NewInstance]) deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)](topLevel = true)) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)](topLevel = true)) == 1) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]()) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]()) == 1) assert(numberOfCheckedArguments( - deserializerFor[(java.lang.Integer, java.lang.Integer)](topLevel = true)) == 0) + deserializerFor[(java.lang.Integer, java.lang.Integer)]()) == 0) } test("SPARK-24762: serializer for Option of Product") { val optionOfProduct = Some((1, "a")) - val topLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( - 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = true) - val nonTopLevelSerializer = serializerFor[Option[(Int, String)]](BoundReference( - 0, ObjectType(optionOfProduct.getClass), nullable = true), topLevel = false) - - topLevelSerializer match { - case CreateNamedStruct(Seq(Literal(_, _), If(_, _, optEncoder))) => - assert(optEncoder.semanticEquals(nonTopLevelSerializer)) + val serializer = serializerFor[Option[(Int, String)]](BoundReference( + 0, ObjectType(optionOfProduct.getClass), nullable = true)) + + serializer match { 
+ case CreateNamedStruct(Seq(_: Literal, If(_, _, encoder: CreateNamedStruct))) => + val fields = encoder.flatten + assert(fields.length == 2) + assert(fields(0).dataType == IntegerType) + assert(fields(1).dataType == StringType) case _ => fail("top-level Option of Product should be encoded as single struct column.") } } test("SPARK-24762: deserializer for Option of Product") { - val topLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = true) - val nonTopLevelDeserializer = deserializerFor[Option[(Int, String)]](topLevel = false) - .asInstanceOf[WrapOption] - - topLevelDeserializer match { - case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), optType) => - assert(n.cls == nonTopLevelDeserializer.child.asInstanceOf[NewInstance].cls) - assert(optType == nonTopLevelDeserializer.optType) + val deserializer = deserializerFor[Option[(Int, String)]]() + + deserializer match { + case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => + assert(n.cls == classOf[Tuple2[Int, String]]) case _ => fail("top-level Option of Product should be decoded from a single struct column.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6d44890704f4..c22d0b19a437 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,25 +19,85 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance, WrapOption} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ import org.apache.spark.util.Utils object TypedAggregateExpression { + + // Checks if given encoder is for `Option[Product]`. + def isOptProductEncoder(encoder: ExpressionEncoder[_]): Boolean = { + // Only Option[Product] is non-flat. + encoder.clsTag.runtimeClass == classOf[Option[_]] && !encoder.flat + } + + /** + * Flattens serializers and deserializer of given encoder. We only flatten encoder + * of `Option[Product]` class. 
+ */ + def flattenOptProductEncoder(encoder: ExpressionEncoder[_]): ExpressionEncoder[_] = { + val serializer = encoder.serializer + val deserializer = encoder.deserializer + + assert(serializer.length == 1, + "We can only flatten encoder of Option of Product class which has single serializer.") + + val flattenSerializers = serializer(0).collect { + case c: CreateNamedStruct => c.flatten + }.head + + val flattenDeserializer = deserializer match { + case w @ WrapOption(If(_, _, child: NewInstance), optType) => + val newInstance = child match { + case oldNewInstance: NewInstance => + val newArguments = oldNewInstance.arguments.zipWithIndex.map { case (arg, idx) => + arg match { + case a @ AssertNotNull( + UpCast(GetStructField( + child @ GetColumnByOrdinal(0, _), _, _), dt, walkedTypePath), _) => + a.copy(child = UpCast(GetColumnByOrdinal(idx, dt), dt, walkedTypePath.tail)) + } + } + oldNewInstance.copy(arguments = newArguments) + } + w.copy(child = newInstance) + case _ => + throw new AnalysisException( + "On top of deserializer of Option[Product] should be `WrapOption`.") + } + + // `Option[Product]` is encoded as single column of struct type in a row. + val newSchema = encoder.schema.asInstanceOf[StructType].fields(0) + .dataType.asInstanceOf[StructType] + encoder.copy(serializer = flattenSerializers, deserializer = flattenDeserializer, + schema = newSchema) + } + def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { - val bufferEncoder = encoderFor[BUF] + val rawBufferEncoder = encoderFor[BUF] + + val bufferEncoder = if (isOptProductEncoder(rawBufferEncoder)) { + flattenOptProductEncoder(rawBufferEncoder) + } else { + rawBufferEncoder + } val bufferSerializer = bufferEncoder.namedExpressions - val outputEncoder = encoderFor[OUT] + val rawOutputEncoder = encoderFor[OUT] + val outputEncoder = if (isOptProductEncoder(rawOutputEncoder)) { + flattenOptProductEncoder(rawOutputEncoder) + } else { + rawOutputEncoder + } val outputType = if (outputEncoder.flat) { outputEncoder.schema.head.dataType } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 33241671dbcf..0446bd9097b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -218,7 +218,7 @@ case class OptionBooleanIntAggregator(colName: String) override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder - def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder(topLevel = false) + def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder() } class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { From 5f95bd0cf1bd308c7df55c41caef7a9f19368f5d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Aug 2018 04:42:33 +0000 Subject: [PATCH 04/16] Remove useless change. 
--- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/ScalaReflectionSuite.scala | 27 +++++++++---------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index bc11191e6959..fbcb1ada1a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -135,7 +135,7 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def deserializerFor[T : TypeTag](): Expression = cleanUpReflectionObjects { + def deserializerFor[T : TypeTag]: Expression = cleanUpReflectionObjects { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index a90137d0029d..0a1c23886159 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -60,7 +60,7 @@ object ExpressionEncoder { AssertNotNull(inputObject, Seq("top level Product input object")) } val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T]() + val deserializer = ScalaReflection.deserializerFor[T] val schema = serializer.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index da9f2f2d0929..750f0b03e46a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -291,7 +291,7 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK 16792: Get correct deserializer for List[_]") { - val listDeserializer = deserializerFor[List[Int]]() + val listDeserializer = deserializerFor[List[Int]] assert(listDeserializer.dataType == ObjectType(classOf[List[_]])) } @@ -301,7 +301,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[Queue[Int]]), nullable = false)) assert(queueSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val queueDeserializer = deserializerFor[Queue[Int]]() + val queueDeserializer = deserializerFor[Queue[Int]] assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer @@ -309,7 +309,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) assert(arrayBufferSerializer.dataType.head.dataType == ArrayType(IntegerType, containsNull = false)) - val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]() + val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } @@ -318,7 +318,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, 
ObjectType(classOf[Map[Int, Int]]), nullable = false)) assert(mapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mapDeserializer = deserializerFor[Map[Int, Int]]() + val mapDeserializer = deserializerFor[Map[Int, Int]] assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap @@ -326,7 +326,7 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) assert(hashMapSerializer.dataType.head.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) - val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]() + val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} @@ -334,14 +334,14 @@ class ScalaReflectionSuite extends SparkFunSuite { 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) assert(linkedHashMapSerializer.dataType.head.dataType == MapType(LongType, StringType, valueContainsNull = true)) - val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]() + val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) - val deserializer = deserializerFor[SpecialCharAsFieldData]() + val deserializer = deserializerFor[SpecialCharAsFieldData] assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") @@ -353,8 +353,8 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-22472: add null check for top-level primitive values") { - assert(deserializerFor[Int]().isInstanceOf[AssertNotNull]) - assert(!deserializerFor[String]().isInstanceOf[AssertNotNull]) + assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) + assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } test("SPARK-23025: schemaFor should support Null type") { @@ -371,10 +371,9 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializer.isInstanceOf[NewInstance]) deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) } - assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]()) == 2) - assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]()) == 1) - assert(numberOfCheckedArguments( - deserializerFor[(java.lang.Integer, java.lang.Integer)]()) == 0) + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } test("SPARK-24762: serializer for Option of Product") { @@ -394,7 +393,7 @@ class ScalaReflectionSuite extends SparkFunSuite { } test("SPARK-24762: deserializer for Option of Product") { - val deserializer = deserializerFor[Option[(Int, String)]]() + val deserializer = deserializerFor[Option[(Int, String)]] deserializer match { case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => From a4f04055b2ba22f371663565710328791942855a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Aug 2018 14:38:16 +0000 Subject: [PATCH 05/16] Add more 
tests. --- .../aggregate/TypedAggregateExpression.scala | 2 +- .../TypedAggregateExpressionSuite.scala | 63 +++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index c22d0b19a437..27a50a270b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -43,7 +43,7 @@ object TypedAggregateExpression { * Flattens serializers and deserializer of given encoder. We only flatten encoder * of `Option[Product]` class. */ - def flattenOptProductEncoder(encoder: ExpressionEncoder[_]): ExpressionEncoder[_] = { + def flattenOptProductEncoder[T](encoder: ExpressionEncoder[T]): ExpressionEncoder[T] = { val serializer = encoder.serializer val deserializer = encoder.deserializer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala new file mode 100644 index 000000000000..f54557b1e0f5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + + +class TypedAggregateExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private def testOptProductEncoder(encoder: ExpressionEncoder[_], expected: Boolean): Unit = { + assert(TypedAggregateExpression.isOptProductEncoder(encoder) == expected) + } + + test("check an encoder is for option of product") { + testOptProductEncoder(encoderFor[Int], false) + testOptProductEncoder(encoderFor[(Long, Long)], false) + testOptProductEncoder(encoderFor[Option[Int]], false) + testOptProductEncoder(encoderFor[Option[(Int, Long)]], true) + testOptProductEncoder(encoderFor[Option[SimpleCaseClass]], true) + } + + test("flatten encoders of option of product") { + // Option[Product] is encoded as a struct column in a row. 
+ val optProductEncoder: ExpressionEncoder[Option[(Int, Long)]] = encoderFor[Option[(Int, Long)]] + val optProductSchema = StructType(StructField("value", StructType( + StructField("_1", IntegerType) :: StructField("_2", LongType) :: Nil)) :: Nil) + + assert(optProductEncoder.schema.length == 1) + assert(DataType.equalsIgnoreCaseAndNullability(optProductEncoder.schema, optProductSchema)) + + val flattenEncoder = TypedAggregateExpression.flattenOptProductEncoder(optProductEncoder) + .resolveAndBind() + assert(flattenEncoder.schema.length == 2) + assert(DataType.equalsIgnoreCaseAndNullability(flattenEncoder.schema, + optProductSchema.fields(0).dataType)) + + val row = flattenEncoder.toRow(Some((1, 2L))) + val expected = flattenEncoder.fromRow(row) + assert(Some((1, 2L)) == expected) + } +} + +case class SimpleCaseClass(a: Int) From c1f798f7e9cba0d04223eed06f1b1f547ec29dc5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 25 Aug 2018 01:52:01 +0000 Subject: [PATCH 06/16] Add test. --- .../org/apache/spark/sql/DatasetSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e292795d6b7c..ce8c5ea6cc00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1506,6 +1506,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) assert(ds.schema == schema) + + val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0))) + val nestedDs = nestedOptData.toDS() + + checkDataset( + nestedDs, + nestedOptData: _*) + + val nestedSchema = StructType(Seq( + StructField("value", StructType(Seq( + StructField("_1", StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true)))), + StructField("_2", DoubleType, nullable = false) + )), nullable = true) + )) + assert(nestedDs.schema == nestedSchema) } test("SPARK-23034 show rdd names in RDD scan nodes") { From 0f029b0a28700334dc6334f1ad89b3124f235a51 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 6 Oct 2018 04:40:07 +0000 Subject: [PATCH 07/16] Improve code comments. --- .../spark/sql/catalyst/ScalaReflection.scala | 8 ++++++-- .../aggregate/TypedAggregateExpression.scala | 14 ++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6e06a755b6f8..2ce7b02afd8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -143,7 +143,8 @@ object ScalaReflection extends ScalaReflection { val isOptionOfProduct = tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) val (optTypePath, nullable) = if (isOptionOfProduct) { - // Top-level Option of Product is encoded as single struct column at top-level row. + // Because we encode top-level Option[Product] as a struct at the first column of the row, + // we add zero ordinal as the path to access it when to deserialize it. 
(Some(addToPathOrdinal(None, 0, dataType, walkedTypePath)), true) } else { (None, tpeNullable) @@ -448,7 +449,10 @@ object ScalaReflection extends ScalaReflection { serializerFor(inputObject, tpe, walkedTypePath) match { case i @ expressions.If(_, _, _: CreateNamedStruct) if tpe.dealias <:< localTypeOf[Option[_]] && definedByConstructorParams(tpe) => - // We encode top-level Option of Product as a single struct column. + // When we are going to serialize an Option[Product] at top-level of row, because + // Spark doesn't support top-level row as null, we encode the Option[Product] as a + // struct at the first column of the row. So here we add an extra named struct wrapping + // the serialized Option[Product] which is the first and only column named `value`. CreateNamedStruct(expressions.Literal("value") :: i :: Nil) case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 27a50a270b09..1cca43e77785 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -35,7 +35,7 @@ object TypedAggregateExpression { // Checks if given encoder is for `Option[Product]`. def isOptProductEncoder(encoder: ExpressionEncoder[_]): Boolean = { - // Only Option[Product] is non-flat. + // For all Option[_] classes, only Option[Product] is reported as not flat. encoder.clsTag.runtimeClass == classOf[Option[_]] && !encoder.flat } @@ -47,8 +47,10 @@ object TypedAggregateExpression { val serializer = encoder.serializer val deserializer = encoder.deserializer + // This is just a sanity check. Encoders of Option[Product] has only one `CreateNamedStruct` + // serializer expression. assert(serializer.length == 1, - "We can only flatten encoder of Option of Product class which has single serializer.") + "We only flatten encoder of Option[Product] class which has single serializer.") val flattenSerializers = serializer(0).collect { case c: CreateNamedStruct => c.flatten @@ -74,8 +76,7 @@ object TypedAggregateExpression { "On top of deserializer of Option[Product] should be `WrapOption`.") } - // `Option[Product]` is encoded as single column of struct type in a row. - val newSchema = encoder.schema.asInstanceOf[StructType].fields(0) + val newSchema = encoder.schema.fields(0) .dataType.asInstanceOf[StructType] encoder.copy(serializer = flattenSerializers, deserializer = flattenDeserializer, schema = newSchema) @@ -85,6 +86,11 @@ object TypedAggregateExpression { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val rawBufferEncoder = encoderFor[BUF] + // When `BUF` or `OUT` is an Option[Product], we need to flatten serializers and deserializer + // of original encoder. It is because we wrap serializers of Option[Product] inside an extra + // struct in order to support encoding of Option[Product] at top-level row. But here we use + // the encoder to encode Option[Product] for a column, we need to get rid of this extra + // struct. 
val bufferEncoder = if (isOptProductEncoder(rawBufferEncoder)) { flattenOptProductEncoder(rawBufferEncoder) } else { From 16af64c3718cb427aa739e4dd764ba4a75e744e8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 27 Oct 2018 22:22:56 +0800 Subject: [PATCH 08/16] Enable top-level Option[Product] encoder. --- .../catalyst/encoders/ExpressionEncoder.scala | 7 +++-- .../sql/catalyst/ScalaReflectionSuite.scala | 27 +++++++++++-------- .../org/apache/spark/sql/DatasetSuite.scala | 2 -- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 54085bb604cb..484c2467eeb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -198,7 +198,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct) { + if (isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -214,6 +214,9 @@ case class ExpressionEncoder[T]( } else { // For other input objects like primitive, array, map, etc., we construct a struct to wrap // the serializer which is a column of an row. + // + // Note: Because Spark SQL doesn't allow top-level row to be null, to encode + // top-level Option[Product] type, we make it as a top-level struct column. CreateNamedStruct(Literal("value") :: objSerializer :: Nil) } }.flatten @@ -227,7 +230,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct) { + if (isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. 
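A minimal usage sketch of what this patch enables, assuming a SparkSession with `spark.implicits._` in scope (it mirrors the DatasetSuite test re-enabled below):

  val ds = Seq(Some((1, "a")), Some((2, "b")), None).toDS()
  ds.printSchema()
  // root
  //  |-- value: struct (nullable = true)
  //  |    |-- _1: integer (nullable = false)
  //  |    |-- _2: string (nullable = true)
  ds.collect()  // contains Some((1, "a")), Some((2, "b")) and None

Before this change, encoding an `Option[Product]` at the top level was rejected outright; now it is stored as the single nullable struct column `value`.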
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f77fc0ad3bbc..26553dee39bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -22,9 +22,9 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue -import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, GetStructField, If, IsNull, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance, WrapOption} import org.apache.spark.sql.types._ case class PrimitiveData( @@ -363,14 +363,15 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } - /* test("SPARK-24762: serializer for Option of Product") { - val optionOfProduct = Some((1, "a")) - val serializer = serializerFor[Option[(Int, String)]](BoundReference( - 0, ObjectType(optionOfProduct.getClass), nullable = true)) + val serializer = serializerFor[Option[(Int, String)]] + val datatype = StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true))) + assert(serializer.dataType == datatype) serializer match { - case CreateNamedStruct(Seq(_: Literal, If(_, _, encoder: CreateNamedStruct))) => + case If(_, _, encoder: CreateNamedStruct) => val fields = encoder.flatten assert(fields.length == 2) assert(fields(0).dataType == IntegerType) @@ -384,11 +385,15 @@ class ScalaReflectionSuite extends SparkFunSuite { val deserializer = deserializerFor[Option[(Int, String)]] deserializer match { - case WrapOption(If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => + case WrapOption(If(IsNull(g @ GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => assert(n.cls == classOf[Tuple2[Int, String]]) + val arguments = n.arguments.flatMap(_.collect { + case g: GetStructField => g + }) + assert(arguments(0) == GetStructField(g, 0)) + assert(arguments(1) == GetStructField(g, 1)) case _ => - fail("top-level Option of Product should be decoded from a single struct column.") + fail("Incorrect deserializer of Option of Product") } } - */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 1544c54aaf3b..f9cdcdc02bbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1539,7 +1539,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(Row("Amsterdam"))) } - /* test("SPARK-24762: Enable top-level Option of Product encoders") { val data = Seq(Some((1, "a")), Some((2, "b")), None) val ds = data.toDS() @@ -1574,7 +1573,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) assert(nestedDs.schema == nestedSchema) } - */ } case class TestDataUnion(x: Int, y: Int, z: Int) 
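The following patch extends this to `Aggregator`, so `Option[Product]` buffer and output encoders work without the flattening helper. An illustrative sketch, borrowing `OptionBooleanIntData` and `OptionBooleanIntAggregator` from DatasetAggregatorSuite:

  val df = Seq(
    OptionBooleanIntData("bob", Some((true, 1))),
    OptionBooleanIntData("bob", Some((false, 2))),
    OptionBooleanIntData("bob", None)).toDF()

  val grouped = df
    .groupBy("name")
    .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
  // `isGood` is a nullable struct<_1: boolean, _2: int> column; the aggregated result is
  // Row("bob", Row(true, 3)), i.e. OptionBooleanIntData("bob", Some((true, 3))) when read
  // back with grouped.as[OptionBooleanIntData].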
From 79d10c1ebc7b29a7d05bc1fb71dd543eab23db24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 27 Oct 2018 22:52:16 +0800 Subject: [PATCH 09/16] Enable Option[Product] encoder for Aggregator. --- .../aggregate/TypedAggregateExpression.scala | 20 +++--- .../spark/sql/DatasetAggregatorSuite.scala | 3 +- .../TypedAggregateExpressionSuite.scala | 63 ------------------- 3 files changed, 8 insertions(+), 78 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index db75757b4116..35a8f36868d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedDeserializer} -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance, WrapOption} +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -76,7 +76,7 @@ object TypedAggregateExpression { None, bufferSerializer, bufferEncoder.resolveAndBind().deserializer, - outputEncoder.serializer, + outputEncoder.objSerializer, outputType, outputEncoder.objSerializer.nullable) } @@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression( inputSchema: Option[StructType], bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, - outputSerializer: Seq[Expression], + outputSerializer: Expression, dataType: DataType, nullable: Boolean, mutableAggBufferOffset: Int = 0, @@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression( aggregator.merge(buffer, input) } - private lazy val resultObjToRow = dataType match { - case _: StructType => - UnsafeProjection.create(CreateStruct(outputSerializer)) - case _ => - assert(outputSerializer.length == 1) - UnsafeProjection.create(outputSerializer.head) - } + private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer) override def eval(buffer: Any): Any = { val resultObj = aggregator.finish(buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index c1fe546a6088..0446bd9097b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -431,7 +431,7 @@ class DatasetAggregatorSuite 
extends QueryTest with SharedSQLContext { assert(grouped.schema == df.schema) checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) } - /* + test("SPARK-24762: Aggregator should be able to use Option of Product encoder") { val df = Seq( OptionBooleanIntData("bob", Some((true, 1))), @@ -445,5 +445,4 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) } - */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala deleted file mode 100644 index f54557b1e0f5..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpressionSuite.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - - -class TypedAggregateExpressionSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - private def testOptProductEncoder(encoder: ExpressionEncoder[_], expected: Boolean): Unit = { - assert(TypedAggregateExpression.isOptProductEncoder(encoder) == expected) - } - - test("check an encoder is for option of product") { - testOptProductEncoder(encoderFor[Int], false) - testOptProductEncoder(encoderFor[(Long, Long)], false) - testOptProductEncoder(encoderFor[Option[Int]], false) - testOptProductEncoder(encoderFor[Option[(Int, Long)]], true) - testOptProductEncoder(encoderFor[Option[SimpleCaseClass]], true) - } - - test("flatten encoders of option of product") { - // Option[Product] is encoded as a struct column in a row. 
- val optProductEncoder: ExpressionEncoder[Option[(Int, Long)]] = encoderFor[Option[(Int, Long)]] - val optProductSchema = StructType(StructField("value", StructType( - StructField("_1", IntegerType) :: StructField("_2", LongType) :: Nil)) :: Nil) - - assert(optProductEncoder.schema.length == 1) - assert(DataType.equalsIgnoreCaseAndNullability(optProductEncoder.schema, optProductSchema)) - - val flattenEncoder = TypedAggregateExpression.flattenOptProductEncoder(optProductEncoder) - .resolveAndBind() - assert(flattenEncoder.schema.length == 2) - assert(DataType.equalsIgnoreCaseAndNullability(flattenEncoder.schema, - optProductSchema.fields(0).dataType)) - - val row = flattenEncoder.toRow(Some((1, 2L))) - val expected = flattenEncoder.fromRow(row) - assert(Some((1, 2L)) == expected) - } -} - -case class SimpleCaseClass(a: Int) From fec1cac2c5f8fa5226001820c24fe5fc8304fe3f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 28 Oct 2018 07:46:30 +0800 Subject: [PATCH 10/16] Update migration guide. --- docs/sql-migration-guide-upgrade.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index c9685b866774..0eb8e2611c47 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,6 +17,8 @@ displayTitle: Spark SQL Upgrading Guide - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. + - In Spark version 2.4 and earlier, `Dataset` doesn't support to encode `Option[Product]` at top-level row, because in Spark SQL entire top-level row can't be null. Since Spark 3.0, `Option[Product]` at top-level is encoded as a row with single struct column. Then with this support, `Aggregator` can also use use `Option[Product]` as buffer and output column types. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. From 3956cdd279921ad1be1f4254e0964f06210a302a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 4 Nov 2018 16:21:57 +0800 Subject: [PATCH 11/16] Revert unnecessary test and doc. --- docs/sql-migration-guide-upgrade.md | 2 - .../sql/catalyst/ScalaReflectionSuite.scala | 40 ++----------------- 2 files changed, 3 insertions(+), 39 deletions(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 0eb8e2611c47..c9685b866774 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -17,8 +17,6 @@ displayTitle: Spark SQL Upgrading Guide - The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set. - - In Spark version 2.4 and earlier, `Dataset` doesn't support to encode `Option[Product]` at top-level row, because in Spark SQL entire top-level row can't be null. Since Spark 3.0, `Option[Product]` at top-level is encoded as a row with single struct column. Then with this support, `Aggregator` can also use use `Option[Product]` as buffer and output column types. 
- ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 26553dee39bb..d98589db323c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -22,9 +22,9 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, GetStructField, If, IsNull, SpecificInternalRow, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance, WrapOption} +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ case class PrimitiveData( @@ -362,38 +362,4 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } - - test("SPARK-24762: serializer for Option of Product") { - val serializer = serializerFor[Option[(Int, String)]] - val datatype = StructType(Seq( - StructField("_1", IntegerType, nullable = false), - StructField("_2", StringType, nullable = true))) - assert(serializer.dataType == datatype) - - serializer match { - case If(_, _, encoder: CreateNamedStruct) => - val fields = encoder.flatten - assert(fields.length == 2) - assert(fields(0).dataType == IntegerType) - assert(fields(1).dataType == StringType) - case _ => - fail("top-level Option of Product should be encoded as single struct column.") - } - } - - test("SPARK-24762: deserializer for Option of Product") { - val deserializer = deserializerFor[Option[(Int, String)]] - - deserializer match { - case WrapOption(If(IsNull(g @ GetColumnByOrdinal(0, _)), _, n: NewInstance), _) => - assert(n.cls == classOf[Tuple2[Int, String]]) - val arguments = n.arguments.flatMap(_.collect { - case g: GetStructField => g - }) - assert(arguments(0) == GetStructField(g, 0)) - assert(arguments(1) == GetStructField(g, 1)) - case _ => - fail("Incorrect deserializer of Option of Product") - } - } } From 8304de84e678d56e00d33119fe3eb6956bfa08d8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 4 Nov 2018 17:48:46 +0800 Subject: [PATCH 12/16] Add more tests. 
--- .../scala/org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4287a49cc7ca..2941abc5fad6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1582,6 +1582,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { )) assert(nestedDs.schema == nestedSchema) } + + test("SPARK-24762: Resolving Option[Product] field") { + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0))).toDS().as[(Int, Option[(String, Double)])] + checkDataset(ds, + (1, Some(("a", 1.0))), (2, Some(("b", 2.0)))) + } + + test("SPARK-24762: select Option[Product] field") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + .select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds, + Some((1, 2)), Some((2, 3)), Some((3, 4))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 3a8a04786e5b8c070fd0b5720674e348e485675b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 5 Nov 2018 08:33:03 +0800 Subject: [PATCH 13/16] Address comments. --- .../apache/spark/sql/DatasetAggregatorSuite.scala | 14 ++++++++++++-- .../scala/org/apache/spark/sql/DatasetSuite.scala | 12 ++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0446bd9097b6..97c3f358c0e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -441,7 +441,17 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val group = df .groupBy("name") .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood")) - assert(df.schema == group.schema) + + val expectedSchema = new StructType() + .add("name", StringType, nullable = true) + .add("isGood", + new StructType() + .add("_1", BooleanType, nullable = false) + .add("_2", IntegerType, nullable = false), + nullable = true) + + assert(df.schema == expectedSchema) + assert(group.schema == expectedSchema) checkAnswer(group, Row("bob", Row(true, 3)) :: Nil) checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2941abc5fad6..11eef98f023a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1556,12 +1556,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds, data: _*) - val schema = StructType(Seq( - StructField("value", StructType(Seq( - StructField("_1", IntegerType, nullable = false), - StructField("_2", StringType, nullable = true) - )), nullable = true) - )) + val schema = new 
StructType().add( + "value", + new StructType() + .add("_1", IntegerType, nullable = false) + .add("_2", StringType, nullable = true), + nullable = true) assert(ds.schema == schema) From 2d2057b4f2dbb541b4f2573944318f7a874fac3d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 9 Nov 2018 23:57:16 +0800 Subject: [PATCH 14/16] Add more option check. --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 9 +++++++-- .../main/scala/org/apache/spark/sql/Dataset.scala | 10 +++++----- .../apache/spark/sql/KeyValueGroupedDataset.scala | 2 +- .../scala/org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++++++ 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index d57b911fbdec..ad42ec7f2fee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -189,7 +189,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)) { + if (isSerializedAsStruct && !isOptionType) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -220,7 +220,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)) { + if (isSerializedAsStruct && !isOptionType) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. @@ -251,6 +251,11 @@ case class ExpressionEncoder[T]( */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] + /** + * Returns true if the type `T` is `Option`. + */ + def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c91b0d778fab..23b8c48d2bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1087,7 +1087,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. 
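    // Illustrative note: when a join side is itself an Option[Product] Dataset (e.g.
    // Dataset[Option[(Int, Int)]]), that side already consists of the single `value` struct
    // column, so it is aliased as a whole below instead of being re-wrapped in another struct.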
val left = { - val combined = if (!this.exprEnc.isSerializedAsStruct) { + val combined = if (!this.exprEnc.isSerializedAsStruct || this.exprEnc.isOptionType) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1097,7 +1097,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.isSerializedAsStruct) { + val combined = if (!other.exprEnc.isSerializedAsStruct || other.exprEnc.isOptionType) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1110,14 +1110,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStruct) { + if (!this.exprEnc.isSerializedAsStruct || this.exprEnc.isOptionType) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStruct) { + if (!other.exprEnc.isSerializedAsStruct || other.exprEnc.isOptionType) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1390,7 +1390,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.isSerializedAsStruct) { + if (!encoder.isSerializedAsStruct || encoder.isOptionType) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 555bcdffb6ee..d2457610880a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStruct) { + val keyColumn = if (!kExprEnc.isSerializedAsStruct || kExprEnc.isOptionType) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 11eef98f023a..6aa622f25630 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1595,6 +1595,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(ds, Some((1, 2)), Some((2, 3)), Some((3, 4))) } + + test("SPARK-24762: joinWith on Option[Product]") { + val ds1 = Seq(Some((1, 2)), Some((2, 3))).toDS().as("a") + val ds2 = Seq(Some((1, 2)), Some((2, 3))).toDS().as("b") + val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") + checkDataset(joined, (Some((2, 3)), Some((1, 2)))) + } + + test("SPARK-24762: typed agg on Option[Product] type") { + val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS() + assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1))) + + assert(ds.groupByKey(x => x).count().collect() === + Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))) + } } case class TestDataUnion(x: Int, y: 
Int, z: Int) From dbd8678445cb4e8c58615d0bfa340a53c68a4f8a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 20 Nov 2018 23:25:07 +0800 Subject: [PATCH 15/16] Improve tests and document. --- .../catalyst/encoders/ExpressionEncoder.scala | 4 ++-- .../aggregate/TypedAggregateExpression.scala | 6 +++--- .../org/apache/spark/sql/DatasetSuite.scala | 17 +++++++++++------ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index ad42ec7f2fee..243a2b5dbfeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -247,12 +247,12 @@ case class ExpressionEncoder[T]( }) /** - * Returns true if the type `T` is serialized as a struct. + * Returns true if the type `T` is serialized as a struct by `objSerializer`. */ def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] /** - * Returns true if the type `T` is `Option`. + * Returns true if the type `T` is an `Option` type. */ def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 35a8f36868d4..b75752945a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -40,9 +40,9 @@ object TypedAggregateExpression { val outputEncoder = encoderFor[OUT] val outputType = outputEncoder.objSerializer.dataType - // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer - // expression is an alias of `BoundReference`, which means the buffer object doesn't need - // serialization. + // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct + // and the serializer expression is an alias of `BoundReference`, which means the buffer + // object doesn't need serialization. 
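+    // Illustrative examples: a primitive buffer such as the Double used by typed.sum typically
+    // counts as simple, whereas a buffer serialized as a struct (for example (Long, Long) or an
+    // Option[Product]) is not, and is evaluated through ComplexTypedAggregateExpression.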
val isSimpleBuffer = { bufferSerializer.head match { case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index fc487fb2dde6..624b15f5e98d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1584,21 +1584,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("SPARK-24762: Resolving Option[Product] field") { - val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0))).toDS().as[(Int, Option[(String, Double)])] + val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS() + .as[(Int, Option[(String, Double)])] checkDataset(ds, - (1, Some(("a", 1.0))), (2, Some(("b", 2.0)))) + (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None)) } test("SPARK-24762: select Option[Product] field") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - .select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) - checkDataset(ds, + val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]]) + checkDataset(ds1, Some((1, 2)), Some((2, 3)), Some((3, 4))) + + val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]]) + checkDataset(ds2, + None, None, Some((3, 4))) } test("SPARK-24762: joinWith on Option[Product]") { - val ds1 = Seq(Some((1, 2)), Some((2, 3))).toDS().as("a") - val ds2 = Seq(Some((1, 2)), Some((2, 3))).toDS().as("b") + val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a") + val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b") val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") checkDataset(joined, (Some((2, 3)), Some((1, 2)))) } From 62fdb17b4f72d935f25041c801708e3939e16074 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 24 Nov 2018 07:56:22 +0800 Subject: [PATCH 16/16] Add helper method. --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 13 +++++++++++-- .../main/scala/org/apache/spark/sql/Dataset.scala | 10 +++++----- .../apache/spark/sql/KeyValueGroupedDataset.scala | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 243a2b5dbfeb..d019924711e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -189,7 +189,7 @@ case class ExpressionEncoder[T]( val serializer: Seq[NamedExpression] = { val clsName = Utils.getSimpleName(clsTag.runtimeClass) - if (isSerializedAsStruct && !isOptionType) { + if (isSerializedAsStructForTopLevel) { val nullSafeSerializer = objSerializer.transformUp { case r: BoundReference => // For input object of Product type, we can't encode it to row if it's null, as Spark SQL @@ -220,7 +220,7 @@ case class ExpressionEncoder[T]( * `GetColumnByOrdinal` with corresponding ordinal. */ val deserializer: Expression = { - if (isSerializedAsStruct && !isOptionType) { + if (isSerializedAsStructForTopLevel) { // We serialized this kind of objects to root-level row. The input of general deserializer // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to // transform attributes accessors. 
@@ -256,6 +256,15 @@ case class ExpressionEncoder[T]( */ def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) + /** + * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in + * the struct are naturally mapped to top-level columns in a row. In other words, the serialized + * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be + * flattened to top-level row, because in Spark SQL top-level row can't be null. This method + * returns true if `T` is serialized as struct and is not `Option` type. + */ + def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType + // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. This // is quite different from normal expressions, and `AttributeReference` doesn't work here diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7c202b05a8ad..c78011485479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1084,7 +1084,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (!this.exprEnc.isSerializedAsStruct || this.exprEnc.isOptionType) { + val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1094,7 +1094,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (!other.exprEnc.isSerializedAsStruct || other.exprEnc.isOptionType) { + val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1107,14 +1107,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. 
val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStruct || this.exprEnc.isOptionType) { + if (!this.exprEnc.isSerializedAsStructForTopLevel) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStruct || other.exprEnc.isOptionType) { + if (!other.exprEnc.isSerializedAsStructForTopLevel) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1387,7 +1387,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (!encoder.isSerializedAsStruct || encoder.isOptionType) { + if (!encoder.isSerializedAsStructForTopLevel) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index bd8b7d241da3..dbb1c313869f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (!kExprEnc.isSerializedAsStruct || kExprEnc.isOptionType) { + val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) { assert(groupingAttributes.length == 1) groupingAttributes.head } else {