Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class VectorUDTSuite extends SparkFunSuite {
}

test("JavaTypeInference with VectorUDT") {
val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint])
val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint], false)
assert(dataType.asInstanceOf[StructType].fields.map(_.dataType)
=== Seq(new VectorUDT, DoubleType))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ object Encoders {
*
* @since 1.6.0
*/
def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
def bean[T](beanClass: Class[T], skipCircularRefField: Boolean = false)
: Encoder[T] = ExpressionEncoder.javaBean(beanClass, skipCircularRefField)

/**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ object JavaTypeInference {
* @param beanClass Java type
* @return (SQL data type, nullable)
*/
def inferDataType(beanClass: Class[_]): (DataType, Boolean) = {
inferDataType(TypeToken.of(beanClass))
def inferDataType(beanClass: Class[_], skipCircularRefField: Boolean)
: (DataType, Boolean) = {
inferDataType(TypeToken.of(beanClass), Set.empty, skipCircularRefField)
}

/**
Expand All @@ -80,11 +81,12 @@ object JavaTypeInference {

/**
* Infers the corresponding SQL data type of a Java type.
* Overload the method with configurable skipCircularRefField
* @param typeToken Java type
* @return (SQL data type, nullable)
*/
private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty)
: (DataType, Boolean) = {
private def inferDataType(typeToken: TypeToken[_], seenTypeSet: Set[Class[_]] = Set.empty,
skipCircularRefField: Boolean = false) : (DataType, Boolean) = {
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true)
Expand Down Expand Up @@ -124,33 +126,42 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType(), true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet,
skipCircularRefField)
(ArrayType(dataType, nullable), true)

case _ if ttIsAssignableFrom(iterableType, typeToken) =>
val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet)
val (dataType, nullable) = inferDataType(elementType(typeToken), seenTypeSet,
skipCircularRefField)
(ArrayType(dataType, nullable), true)

case _ if ttIsAssignableFrom(mapType, typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
val (keyDataType, _) = inferDataType(keyType, seenTypeSet)
val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
val (keyDataType, _) = inferDataType(keyType, seenTypeSet, skipCircularRefField)
val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet, skipCircularRefField)
(MapType(keyDataType, valueDataType, nullable), true)

case other if other.isEnum =>
(StringType, true)

case other =>
if (seenTypeSet.contains(other)) {
if (seenTypeSet.contains(other) && !skipCircularRefField) {
throw QueryExecutionErrors.cannotHaveCircularReferencesInBeanClassError(other)
}

// TODO: we should only collect properties that have getter and setter. However, some tests
// pass in scala case class as java bean class which doesn't have getter and setter.
val properties = getJavaBeanReadableProperties(other)
val fields = properties.map { property =>
val fields = properties.filter(
property =>
!skipCircularRefField ||
!seenTypeSet
.contains(typeToken.method(property.getReadMethod).getReturnType.getRawType)
)
.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other)
val (dataType, nullable) = inferDataType(returnType, seenTypeSet + other,
skipCircularRefField)
// The existence of `javax.annotation.Nonnull`, means this field is not nullable.
val hasNonNull = property.getReadMethod.isAnnotationPresent(classOf[Nonnull])
new StructField(property.getName, dataType, nullable && !hasNonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ object ExpressionEncoder {
}

// TODO: improve error message for java bean encoder.
def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
val schema = JavaTypeInference.inferDataType(beanClass)._1
def javaBean[T](beanClass: Class[T], skipCircularRefField: Boolean = false)
: ExpressionEncoder[T] = {
val schema = JavaTypeInference.inferDataType(beanClass, skipCircularRefField)._1
assert(schema.isInstanceOf[StructType])

val objSerializer = JavaTypeInference.serializerFor(beanClass)
Expand Down