diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index b47ec0b72c63..8a30c81912fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -203,12 +203,10 @@ object Encoders { validatePublicClass[T]() ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - serializer = Seq( + objSerializer = EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - deserializer = + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo), + objDeserializer = DecodeUsingSerializer[T]( Cast(GetColumnByOrdinal(0, BinaryType), BinaryType), classTag[T], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 60dd4a57139e..f32e08044731 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -187,26 +187,23 @@ object JavaTypeInference { } /** - * Returns an expression that can be used to deserialize an internal row to an object of java bean - * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal + * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed + * using `UnresolvedExtractValue`. */ def deserializerFor(beanClass: Class[_]): Expression = { - deserializerFor(TypeToken.of(beanClass), None) + val typeToken = TypeToken.of(beanClass) + deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1)) } - private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path or `GetColumnByOrdinal`. 
*/ - def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1)) + def addToPath(part: String): Expression = UnresolvedExtractValue(path, + expressions.Literal(part)) typeToken.getRawType match { - case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + case c if !inferExternalType(c).isInstanceOf[ObjectType] => path case c if c == classOf[java.lang.Short] || c == classOf[java.lang.Integer] || @@ -219,7 +216,7 @@ object JavaTypeInference { c, ObjectType(c), "valueOf", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.sql.Date] => @@ -227,7 +224,7 @@ object JavaTypeInference { DateTimeUtils.getClass, ObjectType(c), "toJavaDate", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.sql.Timestamp] => @@ -235,14 +232,14 @@ object JavaTypeInference { DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", - getPath :: Nil, + path :: Nil, returnNullable = false) case c if c == classOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) + Invoke(path, "toString", ObjectType(classOf[String])) case c if c == classOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) case c if c.isArray => val elementType = c.getComponentType @@ -258,12 +255,12 @@ object JavaTypeInference { } primitiveMethod.map { method => - Invoke(getPath, method, ObjectType(c)) + Invoke(path, method, ObjectType(c)) }.getOrElse { Invoke( MapObjects( - p => deserializerFor(typeToken.getComponentType, Some(p)), - getPath, + p => deserializerFor(typeToken.getComponentType, p), + path, inferDataType(elementType)._1), "array", ObjectType(c)) @@ -272,8 +269,8 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) UnresolvedMapObjects( - p => deserializerFor(et, Some(p)), - getPath, + p => deserializerFor(et, p), + path, customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => @@ -282,16 +279,16 @@ object JavaTypeInference { val keyData = Invoke( UnresolvedMapObjects( - p => deserializerFor(keyType, Some(p)), - GetKeyArrayFromMap(getPath)), + p => deserializerFor(keyType, p), + GetKeyArrayFromMap(path)), "array", ObjectType(classOf[Array[Any]])) val valueData = Invoke( UnresolvedMapObjects( - p => deserializerFor(valueType, Some(p)), - GetValueArrayFromMap(getPath)), + p => deserializerFor(valueType, p), + GetValueArrayFromMap(path)), "array", ObjectType(classOf[Array[Any]])) @@ -307,7 +304,7 @@ object JavaTypeInference { other, ObjectType(other), "valueOf", - Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil, returnNullable = false) case other => @@ -316,7 +313,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (_, nullable) = inferDataType(fieldType) - val constructor = deserializerFor(fieldType, Some(addToPath(fieldName))) + val constructor = deserializerFor(fieldType, addToPath(fieldName)) val setter = if (nullable) { constructor } else { @@ -328,28 +325,23 @@ object JavaTypeInference { val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - 
expressions.Literal.create(null, ObjectType(other)), - result - ) - } else { + expressions.If( + IsNull(path), + expressions.Literal.create(null, ObjectType(other)), result - } + ) } } /** - * Returns an expression for serializing an object of the given type to an internal row. + * Returns an expression for serializing an object of the given type to a Spark SQL + * representation. The input object is located at ordinal 0 of a row, i.e., + * `BoundReference(0, _)`. */ - def serializerFor(beanClass: Class[_]): CreateNamedStruct = { + def serializerFor(beanClass: Class[_]): Expression = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) - serializerFor(nullSafeInput, TypeToken.of(beanClass)) match { - case expressions.If(_, _, s: CreateNamedStruct) => s - case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) - } + serializerFor(nullSafeInput, TypeToken.of(beanClass)) } private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c27180e2a6b9..40074b36f6a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -24,7 +24,7 @@ import scala.util.Properties import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} @@ -129,21 +129,44 @@ object ScalaReflection { } /** - * Returns an expression that can be used to deserialize an input row to an object of type `T` - * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes - * of the same name as the constructor arguments. Nested classes will have their fields accessed - * using UnresolvedExtractValue. + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff + * and lose the required data type, which may lead to a runtime error if the real type doesn't + * match the encoder's schema. + * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type + * is [a: int, b: long], then we will hit a runtime error saying that we can't construct class + * `Data` with int and long, because we lost the information that `b` should be a string. * - * When used on a primitive type, the constructor will instead default to extracting the value - * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling resolve/bind with a new schema. + * This method helps us "remember" the required data type by adding an `UpCast`. Note that we + * only need to do this for leaf nodes. 
*/ - def deserializerFor[T : TypeTag]: Expression = { - val tpe = localTypeOf[T] + private def upCastToExpectedType(expr: Expression, expected: DataType, + walkedTypePath: Seq[String]): Expression = expected match { + case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. + case _ => UpCast(expr, expected, walkedTypePath) + } + + /** + * Returns an expression that can be used to deserialize a Spark SQL representation to an object + * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of + * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using + * `UnresolvedExtractValue`. + * + * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this + * deserializer expression when using it. + */ + def deserializerForType(tpe: `Type`): Expression = { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - val expr = deserializerFor(tpe, None, walkedTypePath) - val Schema(_, nullable) = schemaFor(tpe) + val Schema(dataType, nullable) = schemaFor(tpe) + + // Assumes we are deserializing the first column of a row. + val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, + walkedTypePath) + + val expr = deserializerFor(tpe, input, walkedTypePath) if (nullable) { expr } else { @@ -151,16 +174,22 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Returns an expression that can be used to deserialize an input expression to an object of type + * `T` with a compatible schema. + * + * @param tpe The `Type` of deserialized object. + * @param path The expression which can be used to extract serialized value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + */ private def deserializerFor( tpe: `Type`, - path: Option[Expression], + path: Expression, walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute.quoted(part)) + val newPath = UnresolvedExtractValue(path, expressions.Literal(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -169,46 +198,12 @@ object ScalaReflection extends ScalaReflection { ordinal: Int, dataType: DataType, walkedTypePath: Seq[String]): Expression = { - val newPath = path - .map(p => GetStructField(p, ordinal)) - .getOrElse(GetColumnByOrdinal(ordinal, dataType)) + val newPath = GetStructField(path, ordinal) upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path or `GetColumnByOrdinal`. */ - def getPath: Expression = { - val dataType = schemaFor(tpe).dataType - if (path.isDefined) { - path.get - } else { - upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) - } - } - - /** - * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff - * and lost the required data type, which may lead to runtime error if the real type doesn't - * match the encoder's schema. 
- * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type - * is [a: int, b: long], then we will hit runtime error and say that we can't construct class - * `Data` with int and long, because we lost the information that `b` should be a string. - * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * only need to do this for leaf nodes. - */ - def upCastToExpectedType( - expr: Expression, - expected: DataType, - walkedTypePath: Seq[String]): Expression = expected match { - case _: StructType => expr - case _: ArrayType => expr - // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and - // it's not trivial to support by-name resolution for StructType inside MapType. - case _ => UpCast(expr, expected, walkedTypePath) - } - tpe.dealias match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -219,44 +214,44 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) + StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil, + path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.sql.Timestamp] => @@ -264,25 +259,25 @@ object 
ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil, + path :: Nil, returnNullable = false) case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) + Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), + Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) + Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), + Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), + Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), returnNullable = false) case t if t <:< localTypeOf[Array[_]] => @@ -294,7 +289,7 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, Some(casted), newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath) if (elementNullable) { converter } else { @@ -302,7 +297,7 @@ object ScalaReflection extends ScalaReflection { } } - val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) if (elementNullable) { @@ -334,7 +329,7 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, Some(casted), newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath) if (elementNullable) { converter } else { @@ -349,16 +344,16 @@ object ScalaReflection extends ScalaReflection { classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - UnresolvedMapObjects(mapFunction, getPath, Some(cls)) + UnresolvedMapObjects(mapFunction, path, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t CatalystToExternalMap( - p => deserializerFor(keyType, Some(p), walkedTypePath), - p => deserializerFor(valueType, Some(p), walkedTypePath), - getPath, + p => deserializerFor(keyType, p, walkedTypePath), + p => deserializerFor(valueType, p, walkedTypePath), + path, mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -368,7 +363,7 @@ object ScalaReflection extends ScalaReflection { udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if UDTRegistration.exists(getClassNameFromType(t)) => val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() @@ -377,7 +372,7 @@ object ScalaReflection extends ScalaReflection { udt.getClass, Nil, dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -392,12 +387,12 @@ object ScalaReflection extends ScalaReflection { val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, - Some(addToPathOrdinal(i, dataType, newTypePath)), + addToPathOrdinal(i, dataType, newTypePath), newTypePath) } else { deserializerFor( fieldType, - Some(addToPath(fieldName, dataType, newTypePath)), + addToPath(fieldName, dataType, newTypePath), newTypePath) } @@ -410,20 +405,17 @@ object ScalaReflection extends ScalaReflection { val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { + expressions.If( + IsNull(path), + expressions.Literal.create(null, ObjectType(cls)), newInstance - } + ) } } /** - * Returns an expression for serializing an object of type T to an internal row. + * Returns an expression for serializing an object of type T to Spark SQL representation. The + * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. * * If the given type is not supported, i.e. 
there is no encoder can be built for this type, * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain @@ -434,17 +426,21 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - val tpe = localTypeOf[T] + def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "$clsName"""" :: Nil - serializerFor(inputObject, tpe, walkedTypePath) match { - case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s - case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) - } + + // The input object to `ExpressionEncoder` is located at the first column of a row. + val inputObject = BoundReference(0, dataTypeFor(tpe), + nullable = !tpe.typeSymbol.asClass.isPrimitive) + + serializerFor(inputObject, tpe, walkedTypePath) } - /** Helper for extracting internal fields from a case class. */ + /** + * Returns an expression for serializing the value of an input expression into the Spark SQL + * internal representation. + */ private def serializerFor( inputObject: Expression, tpe: `Type`, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cbea3c017a26..29f6136a75ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} -import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -43,8 +44,8 @@ import org.apache.spark.util.Utils * to the name `value`. */ object ExpressionEncoder { + def apply[T : TypeTag](): ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. 
val mirror = ScalaReflection.mirror val tpe = typeTag[T].in(mirror).tpe @@ -58,25 +59,11 @@ object ExpressionEncoder { } val cls = mirror.runtimeClass(tpe) - val flat = !ScalaReflection.definedByConstructorParams(tpe) - - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive) - val nullSafeInput = if (flat) { - inputObject - } else { - // For input object of Product type, we can't encode it to row if it's null, as Spark SQL - // doesn't allow top-level row to be null, only its columns can be null. - AssertNotNull(inputObject, Seq("top level Product input object")) - } - val serializer = ScalaReflection.serializerFor[T](nullSafeInput) - val deserializer = ScalaReflection.deserializerFor[T] - - val schema = serializer.dataType + val serializer = ScalaReflection.serializerForType(tpe) + val deserializer = ScalaReflection.deserializerForType(tpe) new ExpressionEncoder[T]( - schema, - flat, - serializer.flatten, + serializer, deserializer, ClassTag[T](cls)) } @@ -86,14 +73,12 @@ object ExpressionEncoder { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) - val serializer = JavaTypeInference.serializerFor(beanClass) - val deserializer = JavaTypeInference.deserializerFor(beanClass) + val objSerializer = JavaTypeInference.serializerFor(beanClass) + val objDeserializer = JavaTypeInference.deserializerFor(beanClass) new ExpressionEncoder[T]( - schema.asInstanceOf[StructType], - flat = false, - serializer.flatten, - deserializer, + objSerializer, + objDeserializer, ClassTag[T](beanClass)) } @@ -103,75 +88,59 @@ object ExpressionEncoder { * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + // TODO: check if encoders length is more than 22 and throw exception for it. + encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { case (e, i) => - val (dataType, nullable) = if (e.flat) { - e.schema.head.dataType -> e.schema.head.nullable - } else { - e.schema -> true - } - StructField(s"_${i + 1}", dataType, nullable) + StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val serializer = encoders.zipWithIndex.map { case (enc, index) => - val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head + val serializers = encoders.zipWithIndex.map { case (enc, index) => + val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct + assert(boundRefs.size == 1, "object serializer should have only one bound reference but " + + s"there are ${boundRefs.size}") + + val originalInputObject = boundRefs.head val newInputObject = Invoke( BoundReference(0, ObjectType(cls), nullable = true), s"_${index + 1}", - originalInputObject.dataType) - - val newSerializer = enc.serializer.map(_.transformUp { - case b: BoundReference if b == originalInputObject => newInputObject - }) + originalInputObject.dataType, + returnNullable = originalInputObject.nullable) - val serializerExpr = if (enc.flat) { - newSerializer.head - } else { - // For non-flat encoder, the input object is not top level anymore after being combined to - // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and - // null check to handle null case correctly. - // e.g. 
for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is - // not able to handle the case when the input tuple is null. This is not a problem as there - // is a check to make sure the input object won't be null. However, if this encoder is used - // to create a bigger tuple encoder, the original input object becomes a filed of the new - // input tuple and can be null. So instead of creating a struct directly here, we should add - // a null/None check and return a null struct if the null/None check fails. - val struct = CreateStruct(newSerializer) - val nullCheck = Or( - IsNull(newInputObject), - Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) - If(nullCheck, Literal.create(null, struct.dataType), struct) + val newSerializer = enc.objSerializer.transformUp { + case b: BoundReference => newInputObject } - Alias(serializerExpr, s"_${index + 1}")() + + Alias(newSerializer, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => - if (enc.flat) { - enc.deserializer.transform { - case g: GetColumnByOrdinal => g.copy(ordinal = index) - } + val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c } + .distinct + assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " + + s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}") + + val input = GetStructField(GetColumnByOrdinal(0, schema), index) + val newDeserializer = enc.objDeserializer.transformUp { + case GetColumnByOrdinal(0, _) => input + } + if (schema(index).nullable) { + If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer) } else { - val input = GetColumnByOrdinal(index, enc.schema) - val deserialized = enc.deserializer.transformUp { - case UnresolvedAttribute(nameParts) => - assert(nameParts.length == 1) - UnresolvedExtractValue(input, Literal(nameParts.head)) - case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal) - } - If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) + newDeserializer } } + val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)), + Literal.create(null, schema), CreateStruct(serializers)) val deserializer = NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( - schema, - flat = false, serializer, deserializer, ClassTag(cls)) @@ -212,21 +181,91 @@ object ExpressionEncoder { * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer` * and a `deserializer`. * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param serializer A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object into an [[InternalRow]]. - * @param deserializer An expression that will construct an object given an [[InternalRow]]. + * @param objSerializer An expression that can be used to encode a raw object to corresponding + * Spark SQL representation that can be a primitive column, array, map or a + * struct. This represents how Spark SQL generally serializes an object of + * type `T`. + * @param objDeserializer An expression that will construct an object given a Spark SQL + * representation. This represents how Spark SQL generally deserializes + * a serialized value in Spark SQL representation back to an object of + * type `T`. * @param clsTag A classtag for `T`. 
*/ case class ExpressionEncoder[T]( - schema: StructType, - flat: Boolean, - serializer: Seq[Expression], - deserializer: Expression, + objSerializer: Expression, + objDeserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(serializer.size == 1) + /** + * A sequence of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]: + * 1. If `serializer` encodes a raw object to a struct, strip the outer If-IsNull and get + * the `CreateNamedStruct`. + * 2. For other cases, wrap the single serializer with `CreateNamedStruct`. + */ + val serializer: Seq[NamedExpression] = { + val clsName = Utils.getSimpleName(clsTag.runtimeClass) + + if (isSerializedAsStruct) { + val nullSafeSerializer = objSerializer.transformUp { + case r: BoundReference => + // For input object of Product type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. + AssertNotNull(r, Seq("top level Product or row object")) + } + nullSafeSerializer match { + case If(_: IsNull, _, s: CreateNamedStruct) => s + case s: CreateNamedStruct => s + case _ => + throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer") + } + } else { + // For other input objects like primitive, array, map, etc., we construct a struct to wrap + // the serializer, which is a single column of a row. + CreateNamedStruct(Literal("value") :: objSerializer :: Nil) + } + }.flatten + + /** + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`s + * of the same name as the constructor arguments. + * + * For complex objects that are encoded to structs, fields of the struct will be extracted using + * `GetColumnByOrdinal` with the corresponding ordinal. + */ + val deserializer: Expression = { + if (isSerializedAsStruct) { + // We serialize objects of this kind to a root-level row. The input of the general deserializer + // is a `GetColumnByOrdinal(0)` expression to extract the first column of a row. We need to + // transform the attribute accessors. + objDeserializer.transform { + case UnresolvedExtractValue(GetColumnByOrdinal(0, _), + Literal(part: UTF8String, StringType)) => + UnresolvedAttribute.quoted(part.toString) + case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) => + GetColumnByOrdinal(ordinal, dt) + case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n + case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i + } + } else { + // For other input objects like primitive, array, map, etc., we deserialize the first column + // of a row to the object. + objDeserializer + } + } + + // The schema after converting `T` to a Spark SQL row. This schema is dependent on the given + // serializer. + val schema: StructType = StructType(serializer.map { s => + StructField(s.name, s.dataType, s.nullable) + }) + + /** + * Returns true if the type `T` is serialized as a struct. + */ + def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType] // serializer expressions are used to encode an object to a row, while the object is usually an // intermediate value produced inside an operator, not from the output of the child operator. 
This @@ -258,7 +297,7 @@ case class ExpressionEncoder[T]( analyzer.checkAnalysis(analyzedPlan) val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer val bound = BindReferences.bindReference(resolved, attrs) - copy(deserializer = bound) + copy(objDeserializer = bound) } @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index ae89f98b1902..d905f8f9858e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -58,12 +58,10 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) - val deserializer = deserializerFor(schema) + val serializer = serializerFor(inputObject, schema) + val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema) new ExpressionEncoder[Row]( - schema, - flat = false, - serializer.asInstanceOf[CreateNamedStruct].flatten, + serializer, deserializer, ClassTag(cls)) } @@ -237,13 +235,9 @@ object RowEncoder { case udt: UserDefinedType[_] => ObjectType(udt.userClass) } - private def deserializerFor(schema: StructType): Expression = { + private def deserializerFor(input: Expression, schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val dt = f.dataType match { - case p: PythonUserDefinedType => p.sqlType - case other => other - } - deserializerFor(GetColumnByOrdinal(i, dt)) + deserializerFor(GetStructField(input, i)) } CreateExternalRow(fields, schema) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f9ee948b97e0..d98589db323c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String case class PrimitiveData( intField: Int, @@ -112,6 +113,14 @@ object TestingUDT { class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + // A helper method used to test `ScalaReflection.serializerForType`. + private def serializerFor[T: TypeTag]: Expression = + serializerForType(ScalaReflection.localTypeOf[T]) + + // A helper method used to test `ScalaReflection.deserializerForType`. 
+ private def deserializerFor[T: TypeTag]: Expression = + deserializerForType(ScalaReflection.localTypeOf[T]) + test("SQLUserDefinedType annotation on Scala structure") { val schema = schemaFor[TestingUDT.NestedStruct] assert(schema === Schema( @@ -263,13 +272,9 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-15062: Get correct serializer for List[_]") { val list = List(1, 2, 3) - val serializer = serializerFor[List[Int]](BoundReference( - 0, ObjectType(list.getClass), nullable = false)) - assert(serializer.children.size == 2) - assert(serializer.children.head.isInstanceOf[Literal]) - assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) - assert(serializer.children.last.isInstanceOf[NewInstance]) - assert(serializer.children.last.asInstanceOf[NewInstance] + val serializer = serializerFor[List[Int]] + assert(serializer.isInstanceOf[NewInstance]) + assert(serializer.asInstanceOf[NewInstance] .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) } @@ -280,59 +285,58 @@ class ScalaReflectionSuite extends SparkFunSuite { test("serialize and deserialize arbitrary sequence types") { import scala.collection.immutable.Queue - val queueSerializer = serializerFor[Queue[Int]](BoundReference( - 0, ObjectType(classOf[Queue[Int]]), nullable = false)) - assert(queueSerializer.dataType.head.dataType == + val queueSerializer = serializerFor[Queue[Int]] + assert(queueSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val queueDeserializer = deserializerFor[Queue[Int]] assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]])) import scala.collection.mutable.ArrayBuffer - val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference( - 0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)) - assert(arrayBufferSerializer.dataType.head.dataType == + val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]] + assert(arrayBufferSerializer.dataType == ArrayType(IntegerType, containsNull = false)) val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]] assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } test("serialize and deserialize arbitrary map types") { - val mapSerializer = serializerFor[Map[Int, Int]](BoundReference( - 0, ObjectType(classOf[Map[Int, Int]]), nullable = false)) - assert(mapSerializer.dataType.head.dataType == + val mapSerializer = serializerFor[Map[Int, Int]] + assert(mapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val mapDeserializer = deserializerFor[Map[Int, Int]] assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]])) import scala.collection.immutable.HashMap - val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference( - 0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false)) - assert(hashMapSerializer.dataType.head.dataType == + val hashMapSerializer = serializerFor[HashMap[Int, Int]] + assert(hashMapSerializer.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false)) val hashMapDeserializer = deserializerFor[HashMap[Int, Int]] assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]])) import scala.collection.mutable.{LinkedHashMap => LHMap} - val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference( - 0, ObjectType(classOf[LHMap[Long, String]]), nullable = false)) - assert(linkedHashMapSerializer.dataType.head.dataType == + val linkedHashMapSerializer = serializerFor[LHMap[Long, String]] + 
assert(linkedHashMapSerializer.dataType == MapType(LongType, StringType, valueContainsNull = true)) val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]] assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]])) } test("SPARK-22442: Generate correct field names for special characters") { - val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( - 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val serializer = serializerFor[SpecialCharAsFieldData] + .collect { + case If(_, _, s: CreateNamedStruct) => s + }.head val deserializer = deserializerFor[SpecialCharAsFieldData] assert(serializer.dataType(0).name == "field.1") assert(serializer.dataType(1).name == "field 2") - val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { - case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts + val newInstance = deserializer.collect { case n: NewInstance => n }.head + + val argumentsFields = newInstance.arguments.flatMap { _.collect { + case UpCast(u: UnresolvedExtractValue, _, _) => u.extraction.toString }} - assert(argumentsFields(0) == Seq("field.1")) - assert(argumentsFields(1) == Seq("field 2")) + assert(argumentsFields(0) == "field.1") + assert(argumentsFields(1) == "field 2") } test("SPARK-22472: add null check for top-level primitive values") { @@ -351,8 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite { test("SPARK-23835: add null check to non-nullable types in Tuples") { def numberOfCheckedArguments(deserializer: Expression): Int = { - assert(deserializer.isInstanceOf[NewInstance]) - deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + val newInstance = deserializer.collect { case n: NewInstance => n}.head + newInstance.arguments.count(_.isInstanceOf[AssertNotNull]) } assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index f0d61de97ffc..e9b100b3b30d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders} import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -348,7 +348,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes test("nullable of encoder serializer") { def checkNullable[T: Encoder](nullable: Boolean): Unit = { - assert(encoderFor[T].serializer.forall(_.nullable === nullable)) + assert(encoderFor[T].objSerializer.nullable === nullable) } // test for flat encoders diff 
--git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 235732134d4b..ab819bec72e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -239,7 +239,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { val encoder = RowEncoder(schema) val e = intercept[RuntimeException](encoder.toRow(null)) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level row object")) + assert(e.getMessage.contains("top level Product or row object")) } test("RowEncoder should validate external type") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0fb3301b3616..c91b0d778fab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1087,7 +1087,7 @@ class Dataset[T] private[sql]( // Note that we do this before joining them, to enable the join operator to return null for one // side, in cases like outer-join. val left = { - val combined = if (this.exprEnc.flat) { + val combined = if (!this.exprEnc.isSerializedAsStruct) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { @@ -1097,7 +1097,7 @@ class Dataset[T] private[sql]( } val right = { - val combined = if (other.exprEnc.flat) { + val combined = if (!other.exprEnc.isSerializedAsStruct) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { @@ -1110,14 +1110,14 @@ class Dataset[T] private[sql]( // combine the outputs of each join side. 
val conditionExpr = joined.condition.get transformUp { case a: Attribute if joined.left.outputSet.contains(a) => - if (this.exprEnc.flat) { + if (!this.exprEnc.isSerializedAsStruct) { left.output.head } else { val index = joined.left.output.indexWhere(_.exprId == a.exprId) GetStructField(left.output.head, index) } case a: Attribute if joined.right.outputSet.contains(a) => - if (other.exprEnc.flat) { + if (!other.exprEnc.isSerializedAsStruct) { right.output.head } else { val index = joined.right.output.indexWhere(_.exprId == a.exprId) @@ -1390,7 +1390,7 @@ class Dataset[T] private[sql]( implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) - if (encoder.flat) { + if (!encoder.isSerializedAsStruct) { new Dataset[U1](sparkSession, project, encoder) } else { // Flattens inner fields of U1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cb..555bcdffb6ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) - val keyColumn = if (kExprEnc.flat) { + val keyColumn = if (!kExprEnc.isSerializedAsStruct) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6d44890704f4..39200ec00e15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -38,18 +38,14 @@ object TypedAggregateExpression { val bufferSerializer = bufferEncoder.namedExpressions val outputEncoder = encoderFor[OUT] - val outputType = if (outputEncoder.flat) { - outputEncoder.schema.head.dataType - } else { - outputEncoder.schema - } + val outputType = outputEncoder.objSerializer.dataType // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer // expression is an alias of `BoundReference`, which means the buffer object doesn't need // serialization. 
val isSimpleBuffer = { bufferSerializer.head match { - case Alias(_: BoundReference, _) if bufferEncoder.flat => true + case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true case _ => false } } @@ -71,7 +67,7 @@ object TypedAggregateExpression { outputEncoder.serializer, outputEncoder.deserializer.dataType, outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + outputEncoder.objSerializer.nullable) } else { ComplexTypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], @@ -82,7 +78,7 @@ object TypedAggregateExpression { bufferEncoder.resolveAndBind().deserializer, outputEncoder.serializer, outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + outputEncoder.objSerializer.nullable) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 4e593ff046a5..27b3b3d78d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1065,7 +1065,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("Dataset should throw RuntimeException if top-level product input object is null") { val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) assert(e.getMessage.contains("Null value appeared in non-nullable field")) - assert(e.getMessage.contains("top level Product input object")) + assert(e.getMessage.contains("top level Product or row object")) } test("dropDuplicates") {
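
The net effect of this patch is that an ExpressionEncoder is now described by a single `objSerializer`/`objDeserializer` expression pair, with `serializer`, `deserializer` and `schema` derived from them, and the old `flat` flag replaced by the `isSerializedAsStruct` check. Below is a minimal sketch of how the refactored API can be exercised, assuming a build of spark-catalyst that includes this patch; the case class `Data`, the object `EncoderRefactorSketch`, and the values noted in the comments are illustrative only and not part of the change.

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical case class used only for illustration.
case class Data(a: Int, b: String)

object EncoderRefactorSketch {
  def main(args: Array[String]): Unit = {
    // A Product type is serialized as a struct: objSerializer produces the whole struct,
    // and the per-field `serializer` expressions and `schema` are derived from it.
    val productEnc = ExpressionEncoder[Data]()
    assert(productEnc.isSerializedAsStruct)      // replaces the old `flat = false`
    println(productEnc.objSerializer.dataType)   // expected: a StructType with fields a, b
    println(productEnc.schema)                   // expected: the same struct, built from `serializer`

    // A primitive type is not serialized as a struct; its single serializer expression
    // is wrapped into a struct under the column name "value".
    val intEnc = ExpressionEncoder[Int]()
    assert(!intEnc.isSerializedAsStruct)         // replaces the old `flat = true`
    println(intEnc.schema)                       // expected: struct<value:int>

    // The object-level expressions can also be produced directly from a reflected type.
    val tpe = ScalaReflection.localTypeOf[Data]
    println(ScalaReflection.serializerForType(tpe))
    println(ScalaReflection.deserializerForType(tpe))
  }
}

As the Dataset, KeyValueGroupedDataset and TypedAggregateExpression hunks above show, the former `encoder.flat` branches throughout the code base reduce to this single `isSerializedAsStruct` test on the encoder.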