diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
new file mode 100644
index 000000000000..e71955ab4e75
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+object DeserializerBuildHelper {
+  /** Returns the current path with a sub-field extracted. */
+  def addToPath(
+      path: Expression,
+      part: String,
+      dataType: DataType,
+      walkedTypePath: Seq[String]): Expression = {
+    val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
+    upCastToExpectedType(newPath, dataType, walkedTypePath)
+  }
+
+  /** Returns the current path with a field at ordinal extracted. */
+  def addToPathOrdinal(
+      path: Expression,
+      ordinal: Int,
+      dataType: DataType,
+      walkedTypePath: Seq[String]): Expression = {
+    val newPath = GetStructField(path, ordinal)
+    upCastToExpectedType(newPath, dataType, walkedTypePath)
+  }
+
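+  /**
+   * Builds the deserializer returned by `funcForCreatingNewExpr` and, when `nullable` is
+   * false, wraps it in `AssertNotNull` so that a null value fails with an error naming the
+   * walked type path instead of surfacing as a NullPointerException later on.
+   */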
+  def deserializerForWithNullSafety(
+      expr: Expression,
+      dataType: DataType,
+      nullable: Boolean,
+      walkedTypePath: Seq[String],
+      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+    val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
+    expressionWithNullSafety(newExpr, nullable, walkedTypePath)
+  }
+
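+  /**
+   * Upcasts `expr` to the type the encoder expects before building the null-safe
+   * deserializer. Used wherever a fresh path such as a column or array element is introduced.
+   */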
+  def deserializerForWithNullSafetyAndUpcast(
+      expr: Expression,
+      dataType: DataType,
+      nullable: Boolean,
+      walkedTypePath: Seq[String],
+      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+    val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
+    deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
+      funcForCreatingNewExpr)
+  }
+
+  private def expressionWithNullSafety(
+      expr: Expression,
+      nullable: Boolean,
+      walkedTypePath: Seq[String]): Expression = {
+    if (nullable) {
+      expr
+    } else {
+      AssertNotNull(expr, walkedTypePath)
+    }
+  }
+
+  def createDeserializerForTypesSupportValueOf(
+      path: Expression,
+      clazz: Class[_]): Expression = {
+    StaticInvoke(
+      clazz,
+      ObjectType(clazz),
+      "valueOf",
+      path :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForString(path: Expression, returnNullable: Boolean): Expression = {
+    Invoke(path, "toString", ObjectType(classOf[java.lang.String]),
+      returnNullable = returnNullable)
+  }
+
+  def createDeserializerForSqlDate(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.sql.Date]),
+      "toJavaDate",
+      path :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForSqlTimestamp(path: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      ObjectType(classOf[java.sql.Timestamp]),
+      "toJavaTimestamp",
+      path :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForJavaBigDecimal(
+      path: Expression,
+      returnNullable: Boolean): Expression = {
+    Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+      returnNullable = returnNullable)
+  }
+
+  def createDeserializerForScalaBigDecimal(
+      path: Expression,
+      returnNullable: Boolean): Expression = {
+    Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = returnNullable)
+  }
+
+  def createDeserializerForJavaBigInteger(
+      path: Expression,
+      returnNullable: Boolean): Expression = {
+    Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+      returnNullable = returnNullable)
+  }
+
+  def createDeserializerForScalaBigInt(path: Expression): Expression = {
+    Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+      returnNullable = false)
+  }
+
+  /**
+   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
+   * and lose the required data type, which may lead to a runtime error if the real type doesn't
+   * match the encoder's schema.
+   * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
+   * is [a: int, b: long], then we will hit a runtime error saying that we can't construct class
+   * `Data` with int and long, because we lost the information that `b` should be a string.
+   *
+   * This method helps us "remember" the required data type by adding an `UpCast`. Note that we
+   * only need to do this for leaf nodes.
+   */
+  private def upCastToExpectedType(
+      expr: Expression,
+      expected: DataType,
+      walkedTypePath: Seq[String]): Expression = expected match {
+    case _: StructType => expr
+    case _: ArrayType => expr
+    case _: MapType => expr
+    case _ => UpCast(expr, expected, walkedTypePath)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 311060e5961c..dafa87839ec6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -26,7 +26,8 @@ import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
 
-import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -194,14 +195,22 @@
    */
   def deserializerFor(beanClass: Class[_]): Expression = {
     val typeToken = TypeToken.of(beanClass)
-    deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1))
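+    // The walked type path records where we are in the object graph; it is threaded
+    // through the recursive calls so errors can name the exact member that failed.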
+    val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
+    val (dataType, nullable) = inferDataType(typeToken)
+
+    // Assumes we are deserializing the first column of a row.
+    deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
+      nullable = nullable, walkedTypePath, (casted, walkedTypePath) => {
+        deserializerFor(typeToken, casted, walkedTypePath)
+      })
   }
 
-  private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
-    /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = UnresolvedExtractValue(path,
-      expressions.Literal(part))
-
+  private def deserializerFor(
+      typeToken: TypeToken[_],
+      path: Expression,
+      walkedTypePath: Seq[String]): Expression = {
     typeToken.getRawType match {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
 
@@ -212,74 +221,81 @@
                 c == classOf[java.lang.Float] ||
                 c == classOf[java.lang.Byte] ||
                 c == classOf[java.lang.Boolean] =>
-        StaticInvoke(
-          c,
-          ObjectType(c),
-          "valueOf",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path, c)
 
       case c if c == classOf[java.sql.Date] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(c),
-          "toJavaDate",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlDate(path)
 
       case c if c == classOf[java.sql.Timestamp] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(c),
-          "toJavaTimestamp",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlTimestamp(path)
 
       case c if c == classOf[java.lang.String] =>
-        Invoke(path, "toString", ObjectType(classOf[String]))
+        createDeserializerForString(path, returnNullable = true)
 
      case c if c == classOf[java.math.BigDecimal] =>
-        Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        createDeserializerForJavaBigDecimal(path, returnNullable = true)
+
+      case c if c == classOf[java.math.BigInteger] =>
+        createDeserializerForJavaBigInteger(path, returnNullable = true)
 
       case c if c.isArray =>
         val elementType = c.getComponentType
-        val primitiveMethod = elementType match {
-          case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
-          case c if c == java.lang.Byte.TYPE => Some("toByteArray")
-          case c if c == java.lang.Short.TYPE => Some("toShortArray")
-          case c if c == java.lang.Integer.TYPE => Some("toIntArray")
-          case c if c == java.lang.Long.TYPE => Some("toLongArray")
-          case c if c == java.lang.Float.TYPE => Some("toFloatArray")
-          case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
-          case _ => None
+        val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
+          walkedTypePath
+        val (dataType, elementNullable) = inferDataType(elementType)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
         }
-        primitiveMethod.map { method =>
-          Invoke(path, method, ObjectType(c))
-        }.getOrElse {
-          Invoke(
-            MapObjects(
-              p => deserializerFor(typeToken.getComponentType, p),
-              path,
-              inferDataType(elementType)._1),
-            "array",
-            ObjectType(c))
+        val arrayData = UnresolvedMapObjects(mapFunction, path)
+
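+        // Primitive element types use the specialized toXxxArray methods, producing e.g.
+        // int[] rather than a boxed Object[]; other element types fall back to "array".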
+        val methodName = elementType match {
+          case c if c == java.lang.Integer.TYPE => "toIntArray"
+          case c if c == java.lang.Long.TYPE => "toLongArray"
+          case c if c == java.lang.Double.TYPE => "toDoubleArray"
+          case c if c == java.lang.Float.TYPE => "toFloatArray"
+          case c if c == java.lang.Short.TYPE => "toShortArray"
+          case c if c == java.lang.Byte.TYPE => "toByteArray"
+          case c if c == java.lang.Boolean.TYPE => "toBooleanArray"
+          // non-primitive
+          case _ => "array"
        }
+        Invoke(arrayData, methodName, ObjectType(c))
 
      case c if listType.isAssignableFrom(typeToken) =>
        val et = elementType(typeToken)
-        UnresolvedMapObjects(
-          p => deserializerFor(et, p),
-          path,
-          customCollectionCls = Some(c))
+        val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
+          walkedTypePath
+        val (dataType, elementNullable) = inferDataType(et)
+        val mapFunction: Expression => Expression = element => {
+          // upcast the array element to the data type the encoder expected.
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(et, casted, typePath))
+        }
+
+        UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))
 
      case _ if mapType.isAssignableFrom(typeToken) =>
        val (keyType, valueType) = mapKeyValueType(typeToken)
+        val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
+          s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
 
        val keyData =
          Invoke(
            UnresolvedMapObjects(
-              p => deserializerFor(keyType, p),
+              p => deserializerFor(keyType, p, newTypePath),
              MapKeys(path)),
            "array",
            ObjectType(classOf[Array[Any]]))
@@ -287,7 +303,7 @@
        val valueData =
          Invoke(
            UnresolvedMapObjects(
-              p => deserializerFor(valueType, p),
+              p => deserializerFor(valueType, p, newTypePath),
              MapValues(path)),
            "array",
            ObjectType(classOf[Array[Any]]))
@@ -300,25 +316,27 @@
            returnNullable = false)
 
      case other if other.isEnum =>
-        StaticInvoke(
-          other,
-          ObjectType(other),
-          "valueOf",
-          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
-          returnNullable = false)
+        createDeserializerForTypesSupportValueOf(
+          createDeserializerForString(path, returnNullable = false),
+          other)
 
      case other =>
        val properties = getJavaBeanReadableAndWritableProperties(other)
        val setters = properties.map { p =>
          val fieldName = p.getName
          val fieldType = typeToken.method(p.getReadMethod).getReturnType
-          val (_, nullable) = inferDataType(fieldType)
-          val constructor = deserializerFor(fieldType, addToPath(fieldName))
-          val setter = if (nullable) {
-            constructor
-          } else {
-            AssertNotNull(constructor, Seq("currently no type path record in java"))
-          }
+          val (dataType, nullable) = inferDataType(fieldType)
+          val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
+            s""", name: "$fieldName")""") +: walkedTypePath
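+          // The type path now names the failing field, replacing the old placeholder
+          // message "currently no type path record in java".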
+          val setter = deserializerForWithNullSafety(
+            path,
+            dataType,
+            nullable = nullable,
+            newTypePath,
+            (expr, typePath) => deserializerFor(fieldType,
+              addToPath(expr, fieldName, dataType, typePath), typePath))
 
          p.getWriteMethod.getName -> setter
        }.toMap
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d5af91acd071..741cba80640b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.sql.catalyst
 
-import java.lang.reflect.Constructor
-
-import scala.util.Properties
-
 import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
+import org.apache.spark.sql.catalyst.expressions.{Expression, _}
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
@@ -128,25 +125,6 @@
     case _ => false
   }
 
-  /**
-   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
-   * and lost the required data type, which may lead to runtime error if the real type doesn't
-   * match the encoder's schema.
-   * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
-   * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
-   * `Data` with int and long, because we lost the information that `b` should be a string.
-   *
-   * This method help us "remember" the required data type by adding a `UpCast`. Note that we
-   * only need to do this for leaf nodes.
-   */
-  private def upCastToExpectedType(expr: Expression, expected: DataType,
-    walkedTypePath: Seq[String]): Expression = expected match {
-    case _: StructType => expr
-    case _: ArrayType => expr
-    case _: MapType => expr
-    case _ => UpCast(expr, expected, walkedTypePath)
-  }
-
   /**
    * Returns an expression that can be used to deserialize a Spark SQL representation to an object
    * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
@@ -162,15 +140,9 @@
     val Schema(dataType, nullable) = schemaFor(tpe)
 
     // Assumes we are deserializing the first column of a row.
-    val input = upCastToExpectedType(
-      GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
-
-    val expr = deserializerFor(tpe, input, walkedTypePath)
-    if (nullable) {
-      expr
-    } else {
-      AssertNotNull(expr, walkedTypePath)
-    }
+    deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType,
+      nullable = nullable, walkedTypePath,
+      (casted, typePath) => deserializerFor(tpe, casted, typePath))
   }
 
   /**
@@ -185,22 +157,6 @@
       tpe: `Type`,
       path: Expression,
       walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
-
-    /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
-      val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
-      upCastToExpectedType(newPath, dataType, walkedTypePath)
-    }
-
-    /** Returns the current path with a field at ordinal extracted. */
-    def addToPathOrdinal(
-        ordinal: Int,
-        dataType: DataType,
-        walkedTypePath: Seq[String]): Expression = {
-      val newPath = GetStructField(path, ordinal)
-      upCastToExpectedType(newPath, dataType, walkedTypePath)
-    }
-
     tpe.dealias match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
 
@@ -211,73 +167,53 @@
        WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
 
      case t if t <:< localTypeOf[java.lang.Integer] =>
-        val boxedType = classOf[java.lang.Integer]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Integer])
 
      case t if t <:< localTypeOf[java.lang.Long] =>
-        val boxedType = classOf[java.lang.Long]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Long])
 
      case t if t <:< localTypeOf[java.lang.Double] =>
-        val boxedType = classOf[java.lang.Double]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Double])
 
      case t if t <:< localTypeOf[java.lang.Float] =>
-        val boxedType = classOf[java.lang.Float]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Float])
 
      case t if t <:< localTypeOf[java.lang.Short] =>
-        val boxedType = classOf[java.lang.Short]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Short])
 
      case t if t <:< localTypeOf[java.lang.Byte] =>
-        val boxedType = classOf[java.lang.Byte]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Byte])
 
      case t if t <:< localTypeOf[java.lang.Boolean] =>
-        val boxedType = classOf[java.lang.Boolean]
-        val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
+        createDeserializerForTypesSupportValueOf(path,
+          classOf[java.lang.Boolean])
 
      case t if t <:< localTypeOf[java.sql.Date] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(classOf[java.sql.Date]),
-          "toJavaDate",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlDate(path)
 
      case t if t <:< localTypeOf[java.sql.Timestamp] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          ObjectType(classOf[java.sql.Timestamp]),
-          "toJavaTimestamp",
-          path :: Nil,
-          returnNullable = false)
+        createDeserializerForSqlTimestamp(path)
 
      case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false)
+        createDeserializerForString(path, returnNullable = false)
 
      case t if t <:< localTypeOf[java.math.BigDecimal] =>
"toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), - returnNullable = false) + createDeserializerForJavaBigDecimal(path, returnNullable = false) case t if t <:< localTypeOf[BigDecimal] => - Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) + createDeserializerForScalaBigDecimal(path, returnNullable = false) case t if t <:< localTypeOf[java.math.BigInteger] => - Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), - returnNullable = false) + createDeserializerForJavaBigInteger(path, returnNullable = false) case t if t <:< localTypeOf[scala.math.BigInt] => - Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), - returnNullable = false) + createDeserializerForScalaBigInt(path) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -287,34 +223,29 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. - val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, casted, newTypePath) - if (elementNullable) { - converter - } else { - AssertNotNull(converter, newTypePath) - } + deserializerForWithNullSafetyAndUpcast( + element, + dataType, + nullable = elementNullable, + newTypePath, + (casted, typePath) => deserializerFor(elementType, casted, typePath)) } val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) - if (elementNullable) { - Invoke(arrayData, "array", arrayCls, returnNullable = false) - } else { - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => "toIntArray" - case t if t <:< definitions.LongTpe => "toLongArray" - case t if t <:< definitions.DoubleTpe => "toDoubleArray" - case t if t <:< definitions.FloatTpe => "toFloatArray" - case t if t <:< definitions.ShortTpe => "toShortArray" - case t if t <:< definitions.ByteTpe => "toByteArray" - case t if t <:< definitions.BooleanTpe => "toBooleanArray" - case other => throw new IllegalStateException("expect primitive array element type " + - "but got " + other) - } - Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) + val methodName = elementType match { + case t if t <:< definitions.IntTpe => "toIntArray" + case t if t <:< definitions.LongTpe => "toLongArray" + case t if t <:< definitions.DoubleTpe => "toDoubleArray" + case t if t <:< definitions.FloatTpe => "toFloatArray" + case t if t <:< definitions.ShortTpe => "toShortArray" + case t if t <:< definitions.ByteTpe => "toByteArray" + case t if t <:< definitions.BooleanTpe => "toBooleanArray" + // non-primitive + case _ => "array" } + Invoke(arrayData, methodName, arrayCls, returnNullable = false) // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. @@ -326,14 +257,12 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. 
-          val casted = upCastToExpectedType(element, dataType, newTypePath)
-          val converter = deserializerFor(elementType, casted, newTypePath)
-          if (elementNullable) {
-            converter
-          } else {
-            AssertNotNull(converter, newTypePath)
-          }
+          deserializerForWithNullSafetyAndUpcast(
+            element,
+            dataType,
+            nullable = elementNullable,
+            newTypePath,
+            (casted, typePath) => deserializerFor(elementType, casted, typePath))
        }
 
        val companion = t.dealias.typeSymbol.companion.typeSignature
@@ -346,13 +275,18 @@
        UnresolvedMapObjects(mapFunction, path, Some(cls))
 
      case t if t <:< localTypeOf[Map[_, _]] =>
-        // TODO: add walked type path for map
        val TypeRef(_, _, Seq(keyType, valueType)) = t
+        val classNameForKey = getClassNameFromType(keyType)
+        val classNameForValue = getClassNameFromType(valueType)
+
+        val newTypePath = (s"""- map key class: "${classNameForKey}"""" +
+          s""", value class: "${classNameForValue}"""") +: walkedTypePath
+
        UnresolvedCatalystToExternalMap(
          path,
-          p => deserializerFor(keyType, p, walkedTypePath),
-          p => deserializerFor(valueType, p, walkedTypePath),
+          p => deserializerFor(keyType, p, newTypePath),
+          p => deserializerFor(valueType, p, newTypePath),
          mirror.runtimeClass(t.typeSymbol.asClass)
        )
 
@@ -382,25 +316,30 @@
        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
          val Schema(dataType, nullable) = schemaFor(fieldType)
          val clsName = getClassNameFromType(fieldType)
-          val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-          // For tuples, we based grab the inner fields by ordinal instead of name.
-          val constructor = if (cls.getName startsWith "scala.Tuple") {
-            deserializerFor(
-              fieldType,
-              addToPathOrdinal(i, dataType, newTypePath),
-              newTypePath)
-          } else {
-            deserializerFor(
-              fieldType,
-              addToPath(fieldName, dataType, newTypePath),
-              newTypePath)
-          }
+          val newTypePath = (s"""- field (class: "$clsName", """ +
+            s"""name: "$fieldName")""") +: walkedTypePath
 
-          if (!nullable) {
-            AssertNotNull(constructor, newTypePath)
-          } else {
-            constructor
-          }
+          // For tuples, we grab the inner fields by ordinal instead of by name.
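+          // addToPath/addToPathOrdinal extract the field value and also upcast it to the
+          // type the encoder expects before the field deserializer is applied.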
+          deserializerForWithNullSafety(
+            path,
+            dataType,
+            nullable = nullable,
+            newTypePath,
+            (expr, typePath) => {
+              if (cls.getName startsWith "scala.Tuple") {
+                deserializerFor(
+                  fieldType,
+                  addToPathOrdinal(expr, i, dataType, typePath),
+                  newTypePath)
+              } else {
+                deserializerFor(
+                  fieldType,
+                  addToPath(expr, fieldName, dataType, typePath),
+                  newTypePath)
+              }
+            })
        }
 
        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 8f35abeb579b..49ff522cee8e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -20,11 +20,12 @@
 import java.io.Serializable;
 import java.util.*;
 
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructType;
 import org.junit.*;
 
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.test.TestSparkSession;
 
 public class JavaBeanDeserializationSuite implements Serializable {
@@ -115,6 +116,111 @@
     Assert.assertEquals(records, MAP_RECORDS);
   }
 
+  @Test
+  public void testSpark22000() {
+    List<Row> inputRows = new ArrayList<>();
+    List<RecordSpark22000> expectedRecords = new ArrayList<>();
+
+    for (long idx = 0; idx < 5; idx++) {
+      Row row = createRecordSpark22000Row(idx);
+      inputRows.add(row);
+      expectedRecords.add(createRecordSpark22000(row));
+    }
+
+    // Here we try to convert the fields from any type to string.
+    // Before applying SPARK-22000, Spark called toString() against a variable whose type
+    // might be primitive.
+    // With SPARK-22000 it calls String.valueOf(), which ultimately calls toString() but
+    // handles boxing if the type is primitive.
+    Encoder<RecordSpark22000> encoder = Encoders.bean(RecordSpark22000.class);
+
+    StructType schema = new StructType()
+      .add("shortField", DataTypes.ShortType)
+      .add("intField", DataTypes.IntegerType)
+      .add("longField", DataTypes.LongType)
+      .add("floatField", DataTypes.FloatType)
+      .add("doubleField", DataTypes.DoubleType)
+      .add("stringField", DataTypes.StringType)
+      .add("booleanField", DataTypes.BooleanType)
+      .add("timestampField", DataTypes.TimestampType)
+      // explicitly setting nullable = true to make the intention clear
+      .add("nullIntField", DataTypes.IntegerType, true);
+
+    Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+    Dataset<RecordSpark22000> dataset = dataFrame.as(encoder);
+
+    List<RecordSpark22000> records = dataset.collectAsList();
+
+    Assert.assertEquals(records, expectedRecords);
+  }
+
+  @Test
+  public void testSpark22000FailToUpcast() {
+    List<Row> inputRows = new ArrayList<>();
+    for (long idx = 0; idx < 5; idx++) {
+      Row row = createRecordSpark22000FailToUpcastRow(idx);
+      inputRows.add(row);
+    }
+
+    // Here we try to convert the fields from string type to int, which upcast doesn't help.
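+    // Upcast only permits safe conversions (e.g. int to long, or anything to string);
+    // string to int could lose data, so analysis should fail with "Cannot up cast".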
+    Encoder<RecordSpark22000FailToUpcast> encoder =
+      Encoders.bean(RecordSpark22000FailToUpcast.class);
+
+    StructType schema = new StructType().add("id", DataTypes.StringType);
+
+    Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+
+    try {
+      dataFrame.as(encoder).collect();
+      Assert.fail("Expected AnalysisException, but passed.");
+    } catch (Throwable e) {
+      // Weird case: the compiler complains that AnalysisException is never thrown in the
+      // try block, yet it can be thrown at runtime. Possibly a Scala-Java interop issue?
+      if (e instanceof AnalysisException) {
+        Assert.assertTrue(e.getMessage().contains("Cannot up cast "));
+      } else {
+        throw e;
+      }
+    }
+  }
+
+  private static Row createRecordSpark22000Row(Long index) {
+    Object[] values = new Object[] {
+      index.shortValue(),
+      index.intValue(),
+      index,
+      index.floatValue(),
+      index.doubleValue(),
+      String.valueOf(index),
+      index % 2 == 0,
+      new java.sql.Timestamp(System.currentTimeMillis()),
+      null
+    };
+    return new GenericRow(values);
+  }
+
+  private static RecordSpark22000 createRecordSpark22000(Row recordRow) {
+    RecordSpark22000 record = new RecordSpark22000();
+    record.setShortField(String.valueOf(recordRow.getShort(0)));
+    record.setIntField(String.valueOf(recordRow.getInt(1)));
+    record.setLongField(String.valueOf(recordRow.getLong(2)));
+    record.setFloatField(String.valueOf(recordRow.getFloat(3)));
+    record.setDoubleField(String.valueOf(recordRow.getDouble(4)));
+    record.setStringField(recordRow.getString(5));
+    record.setBooleanField(String.valueOf(recordRow.getBoolean(6)));
+    record.setTimestampField(String.valueOf(recordRow.getTimestamp(7).getTime() * 1000));
+    // This checks that a null value does not become the string "null".
+    record.setNullIntField(null);
+    return record;
+  }
+
+  private static Row createRecordSpark22000FailToUpcastRow(Long index) {
+    Object[] values = new Object[] { String.valueOf(index) };
+    return new GenericRow(values);
+  }
+
   public static class ArrayRecord {
 
     private int id;
@@ -252,4 +358,142 @@
       return String.format("[%d,%d]", startTime, endTime);
     }
   }
+
+  public static final class RecordSpark22000 {
+    private String shortField;
+    private String intField;
+    private String longField;
+    private String floatField;
+    private String doubleField;
+    private String stringField;
+    private String booleanField;
+    private String timestampField;
+    private String nullIntField;
+
+    public RecordSpark22000() { }
+
+    public String getShortField() {
+      return shortField;
+    }
+
+    public void setShortField(String shortField) {
+      this.shortField = shortField;
+    }
+
+    public String getIntField() {
+      return intField;
+    }
+
+    public void setIntField(String intField) {
+      this.intField = intField;
+    }
+
+    public String getLongField() {
+      return longField;
+    }
+
+    public void setLongField(String longField) {
+      this.longField = longField;
+    }
+
+    public String getFloatField() {
+      return floatField;
+    }
+
+    public void setFloatField(String floatField) {
+      this.floatField = floatField;
+    }
+
+    public String getDoubleField() {
+      return doubleField;
+    }
+
+    public void setDoubleField(String doubleField) {
+      this.doubleField = doubleField;
+    }
+
+    public String getStringField() {
+      return stringField;
+    }
+
+    public void setStringField(String stringField) {
+      this.stringField = stringField;
+    }
+
+    public String getBooleanField() {
+      return booleanField;
+    }
+
+    public void setBooleanField(String booleanField) {
+      this.booleanField = booleanField;
+    }
+
+    public String getTimestampField() {
+      return timestampField;
+    }
+
+    public void setTimestampField(String timestampField) {
+      this.timestampField = timestampField;
+    }
+
+    public String getNullIntField() {
+      return nullIntField;
+    }
+
+    public void setNullIntField(String nullIntField) {
+      this.nullIntField = nullIntField;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      RecordSpark22000 that = (RecordSpark22000) o;
+      return Objects.equals(shortField, that.shortField) &&
+        Objects.equals(intField, that.intField) &&
+        Objects.equals(longField, that.longField) &&
+        Objects.equals(floatField, that.floatField) &&
+        Objects.equals(doubleField, that.doubleField) &&
+        Objects.equals(stringField, that.stringField) &&
+        Objects.equals(booleanField, that.booleanField) &&
+        Objects.equals(timestampField, that.timestampField) &&
+        Objects.equals(nullIntField, that.nullIntField);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(shortField, intField, longField, floatField, doubleField, stringField,
+        booleanField, timestampField, nullIntField);
+    }
+
+    @Override
+    public String toString() {
+      return com.google.common.base.Objects.toStringHelper(this)
+        .add("shortField", shortField)
+        .add("intField", intField)
+        .add("longField", longField)
+        .add("floatField", floatField)
+        .add("doubleField", doubleField)
+        .add("stringField", stringField)
+        .add("booleanField", booleanField)
+        .add("timestampField", timestampField)
+        .add("nullIntField", nullIntField)
+        .toString();
+    }
+  }
+
+  public static final class RecordSpark22000FailToUpcast {
+    private Integer id;
+
+    public RecordSpark22000FailToUpcast() {
+    }
+
+    public Integer getId() {
+      return id;
+    }
+
+    public void setId(Integer id) {
+      this.id = id;
+    }
+  }
 }