sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -117,11 +117,10 @@ object JavaTypeInference {
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)

case _ =>
case other =>
// TODO: we should only collect properties that have getter and setter. However, some tests
// pass in scala case class as java bean class which doesn't have getter and setter.
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val properties = getJavaBeanReadableProperties(other)
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType)
@@ -131,10 +130,15 @@ object JavaTypeInference {
}
}

private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
beanInfo.getPropertyDescriptors
.filter(p => p.getReadMethod != null && p.getWriteMethod != null)
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
.filter(_.getReadMethod != null)
}

private def getJavaBeanReadableAndWritableProperties(
beanClass: Class[_]): Array[PropertyDescriptor] = {
getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
}

private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
@@ -298,9 +302,7 @@ object JavaTypeInference {
keyData :: valueData :: Nil)

case other =>
val properties = getJavaBeanProperties(other)
assert(properties.length > 0)

val properties = getJavaBeanReadableAndWritableProperties(other)
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
@@ -417,21 +419,16 @@ object JavaTypeInference {
)

case other =>
val properties = getJavaBeanProperties(other)
if (properties.length > 0) {
CreateNamedStruct(properties.flatMap { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})
} else {
throw new UnsupportedOperationException(
s"Cannot infer type for class ${other.getName} because it is not bean-compliant")
}
val properties = getJavaBeanReadableAndWritableProperties(other)
CreateNamedStruct(properties.flatMap { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})
}
}
}
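The net effect of the JavaTypeInference changes above: schema inference now keeps any property with a getter, the serializer and deserializer use only properties with both a getter and a setter, and neither path asserts or throws when no such properties exist. A minimal, standalone sketch of the two filters, using a hypothetical SetterOnlyBean that is not part of this PR:

```scala
import java.beans.{Introspector, PropertyDescriptor}

// Hypothetical bean with a setter but no getter, mirroring the BeanWithoutGetter test below.
class SetterOnlyBean {
  private var a: String = _
  def setA(value: String): Unit = { this.a = value }
}

object PropertyFilters {
  def main(args: Array[String]): Unit = {
    val descriptors: Array[PropertyDescriptor] =
      Introspector.getBeanInfo(classOf[SetterOnlyBean]).getPropertyDescriptors

    // Readable: drop the synthetic "class" property, keep anything with a getter.
    val readable = descriptors
      .filterNot(_.getName == "class")
      .filter(_.getReadMethod != null)

    // Readable and writable: additionally require a setter.
    val readableAndWritable = readable.filter(_.getWriteMethod != null)

    // Both are empty for SetterOnlyBean, so inference yields an empty struct
    // instead of failing with an UnsupportedOperationException.
    println(readable.length)             // 0
    println(readableAndWritable.length)  // 0
  }
}
```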
6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1090,14 +1090,14 @@ object SQLContext {
*/
private[sql] def beansToRows(
data: Iterator[_],
beanInfo: BeanInfo,
beanClass: Class[_],
attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
val extractors =
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
(e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
}
data.map{ element =>
data.map { element =>
new GenericInternalRow(
methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
): InternalRow
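With the new signature, beansToRows rediscovers the readable properties from the bean class itself rather than from a pre-computed BeanInfo. Roughly, the extraction step amounts to the following standalone sketch (Catalyst conversion omitted; the names BeanExtraction and extractValues are illustrative only):

```scala
import java.beans.Introspector

object BeanExtraction {
  // Invoke every readable property's getter reflectively to pull out the column
  // values for one bean; this mirrors what beansToRows does before converting
  // each value to Catalyst format.
  def extractValues(bean: AnyRef): Seq[AnyRef] = {
    Introspector.getBeanInfo(bean.getClass).getPropertyDescriptors
      .filterNot(_.getName == "class")
      .filter(_.getReadMethod != null)
      .map(_.getReadMethod.invoke(bean))
      .toSeq
  }
}
```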
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,7 +17,6 @@

package org.apache.spark.sql

import java.beans.Introspector
import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference

@@ -347,8 +346,7 @@ class SparkSession private(
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
// BeanInfo is not serializable so we must rediscover it remotely for each partition.
val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq)
}
Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
}
@@ -374,8 +372,7 @@ class SparkSession private(
*/
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
val attrSeq = getSchema(beanClass)
val beanInfo = Introspector.getBeanInfo(beanClass)
val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
}

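The SparkSession change above works because only the class name (a String) is captured by the closure, and each partition rediscovers the bean metadata locally via Utils.classForName. The BeanInfo itself is not Serializable and cannot be shipped to executors; a small sketch of that constraint (not Spark code, and the object name is made up):

```scala
import java.beans.Introspector
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object BeanInfoNotSerializable {
  def main(args: Array[String]): Unit = {
    val beanInfo = Introspector.getBeanInfo(classOf[java.util.Date])
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      // Fails: the BeanInfo returned by the Introspector does not implement Serializable.
      out.writeObject(beanInfo)
    } catch {
      case e: NotSerializableException =>
        println(s"cannot ship BeanInfo in a closure: ${e.getMessage}")
    }
    // Shipping the class name instead is fine; each partition can call
    // Introspector.getBeanInfo on the re-loaded class locally.
  }
}
```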
sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -397,4 +397,21 @@ public void testBloomFilter() {
Assert.assertTrue(filter4.mightContain(i * 3));
}
}

public static class BeanWithoutGetter implements Serializable {
private String a;

public void setA(String a) {
this.a = a;
}
}

@Test
public void testBeanWithoutGetter() {
BeanWithoutGetter bean = new BeanWithoutGetter();
List<BeanWithoutGetter> data = Arrays.asList(bean);
Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
Assert.assertEquals(df.schema().length(), 0);
Assert.assertEquals(df.collectAsList().size(), 1);
}
}
sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1276,4 +1276,15 @@ public void test() {
spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
ds.collectAsList();
}

public static class EmptyBean implements Serializable {}

@Test
public void testEmptyBean() {
EmptyBean bean = new EmptyBean();
List<EmptyBean> data = Arrays.asList(bean);
Dataset<EmptyBean> df = spark.createDataset(data, Encoders.bean(EmptyBean.class));
Assert.assertEquals(df.schema().length(), 0);
Assert.assertEquals(df.collectAsList().size(), 1);
}
}