diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 67c64f762b25..d634a7bfcc3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -40,7 +40,7 @@ class VectorUDTSuite extends SparkFunSuite { } test("JavaTypeInference with VectorUDT") { - val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint]) + val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint], false) assert(dataType.asInstanceOf[StructType].fields.map(_.dataType) === Seq(new VectorUDT, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index a41980448865..fb095165dff8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -176,7 +176,8 @@ object Encoders { * * @since 1.6.0 */ - def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + def bean[T](beanClass: Class[T], skipCircularRefField: Boolean = false) + : Encoder[T] = ExpressionEncoder.javaBean(beanClass, skipCircularRefField) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 903072ae29d8..7b400b766b50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -65,8 +65,9 @@ object JavaTypeInference { * @param beanClass Java type * @return (SQL data type, nullable) */ - def inferDataType(beanClass: Class[_]): (DataType, Boolean) = { - inferDataType(TypeToken.of(beanClass)) + def inferDataType(beanClass: Class[_], skipCircularRefField: Boolean) + : (DataType, Boolean) = { + inferDataType(TypeToken.of(beanClass), Set.empty, skipCircularRefField) } /** @@ -80,11 +81,12 @@ object JavaTypeInference { /** * Infers the corresponding SQL data type of a Java type. + * Overload the method with configurable skipCircularRefField * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty) - : (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty, + skipCircularRefField: Boolean = false) : (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true) @@ -124,33 +126,42 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType(), true) case _ if typeToken.isArray => - val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) + val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet, + skipCircularRefField) (ArrayType(dataType, nullable), true) case _ if ttIsAssignableFrom(iterableType, typeToken) => - val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet) + val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet, + skipCircularRefField) (ArrayType(dataType, nullable), true) case _ if ttIsAssignableFrom(mapType, typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) - val (keyDataType, _) = inferDataType(keyType, seenTypeSet) - val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet) + val (keyDataType, _) = inferDataType(keyType, seenTypeSet, skipCircularRefField) + val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet, skipCircularRefField) (MapType(keyDataType, valueDataType, nullable), true) case other if other.isEnum => (StringType, true) case other => - if (seenTypeSet.contains(other)) { + if (seenTypeSet.contains(other) && !skipCircularRefField) { throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(other) } // TODO: we should only collect properties that have getter and setter. However, some tests // pass in scala case class as java bean class which doesn't have getter and setter. val properties = getJavaBeanReadableProperties(other) - val fields = properties.map { property => + val fields = properties.filter( + property => + !skipCircularRefField || + !seenTypeSet + .contains(typeToken.method(property.getReadMethod).getReturnType.getRawType) + ) + .map { property => val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other) + val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other, + skipCircularRefField) // The existence of `javax.annotation.Nonnull`, means this field is not nullable. val hasNonNull = property.getReadMethod.isAnnotationPresent(classOf[Nonnull]) new StructField(property.getName, dataType, nullable && !hasNonNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index a53914b5f7ab..69a0af89392e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -61,8 +61,9 @@ object ExpressionEncoder { } // TODO: improve error message for java bean encoder. - def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { - val schema = JavaTypeInference.inferDataType(beanClass)._1 + def javaBean[T](beanClass: Class[T], skipCircularRefField: Boolean = false) + : ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass, skipCircularRefField)._1 assert(schema.isInstanceOf[StructType]) val objSerializer = JavaTypeInference.serializerFor(beanClass)