diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index a1d7b1108bf7..0fa1cbc56844 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -304,9 +304,9 @@ registered as a table. Tables can be used in subsequent SQL statements.
 
 Spark SQL supports automatically converting an RDD of
 [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) into a DataFrame.
-The `BeanInfo`, obtained using reflection, defines the schema of the table. Currently, Spark SQL
-does not support JavaBeans that contain `Map` field(s). Nested JavaBeans and `List` or `Array`
-fields are supported though. You can create a JavaBean by creating a class that implements
+The `BeanInfo`, obtained using reflection, defines the schema of the table. Spark SQL supports
+fields that contain `List`, `Array`, `Map` or a nested JavaBean. JavaBeans are also supported as collection elements.
+You can create a JavaBean by creating a class that implements
 Serializable and has getters and setters for all of its fields.
 
 {% include_example schema_inferring java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dfb12f272eb2..f36e66e1d4a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql
 
+import java.lang.reflect.{Array => JavaArray, ParameterizedType, Type}
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.collection.immutable
 import scala.reflect.runtime.universe.TypeTag
 
@@ -30,6 +32,7 @@ import org.apache.spark.internal.config.ConfigEntry
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.execution.command.ShowTablesCommand
 import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
 import org.apache.spark.sql.sources.BaseRelation
@@ -1098,12 +1101,20 @@ object SQLContext {
       data: Iterator[_],
       beanClass: Class[_],
       attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
+    def interfaceParameters(t: Type, interface: Class[_], dataType: DataType): Array[Type] =
+      t match {
+        case parType: ParameterizedType if parType.getRawType == interface =>
+          parType.getActualTypeArguments
+        case _ => throw new UnsupportedOperationException(
+          s"Type ${t.getTypeName} is not supported for data type ${dataType.simpleString}. " +
+            s"Expected ${interface.getName}")
+      }
     def createStructConverter(cls: Class[_], fieldTypes: Seq[DataType]): Any => InternalRow = {
       val methodConverters =
         JavaTypeInference.getJavaBeanReadableProperties(cls).zip(fieldTypes)
           .map { case (property, fieldType) =>
             val method = property.getReadMethod
-            method -> createConverter(method.getReturnType, fieldType)
+            method -> createConverter(method.getGenericReturnType, fieldType)
           }
       value =>
         if (value == null) {
@@ -1115,9 +1126,38 @@ object SQLContext {
            })
        }
     }
-    def createConverter(cls: Class[_], dataType: DataType): Any => Any = dataType match {
-      case struct: StructType => createStructConverter(cls, struct.map(_.dataType))
-      case _ => CatalystTypeConverters.createToCatalystConverter(dataType)
+    def createConverter(t: Type, dataType: DataType): Any => Any = (t, dataType) match {
+      case (cls: Class[_], struct: StructType) =>
+        // bean type
+        createStructConverter(cls, struct.map(_.dataType))
+      case (arrayType: Class[_], array: ArrayType) if arrayType.isArray =>
+        // array type
+        val converter = createConverter(arrayType.getComponentType, array.elementType)
+        value => new GenericArrayData(
+          (0 until JavaArray.getLength(value)).map(i =>
+            converter(JavaArray.get(value, i))).toArray)
+      case (_, array: ArrayType) =>
+        // java.util.List type
+        val cls = classOf[java.util.List[_]]
+        val params = interfaceParameters(t, cls, dataType)
+        val converter = createConverter(params(0), array.elementType)
+        value => new GenericArrayData(
+          value.asInstanceOf[java.util.List[_]].asScala.map(converter).toArray)
+      case (_, map: MapType) =>
+        // java.util.Map type
+        val cls = classOf[java.util.Map[_, _]]
+        val params = interfaceParameters(t, cls, dataType)
+        val keyConverter = createConverter(params(0), map.keyType)
+        val valueConverter = createConverter(params(1), map.valueType)
+        value => {
+          val (keys, values) = value.asInstanceOf[java.util.Map[_, _]].asScala.unzip[Any, Any]
+          new ArrayBasedMapData(
+            new GenericArrayData(keys.map(keyConverter).toArray),
+            new GenericArrayData(values.map(valueConverter).toArray))
+        }
+      case _ =>
+        // other types
+        CatalystTypeConverters.createToCatalystConverter(dataType)
     }
     val dataConverter = createStructConverter(beanClass, attrs.map(_.dataType))
     data.map(dataConverter)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index a05afa4f6ba3..21035d50caaf 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -136,6 +136,9 @@ public static class Bean implements Serializable {
     private BigInteger e = new BigInteger("1234567");
     private NestedBean f = new NestedBean();
     private NestedBean g = null;
+    private NestedBean[] h = new NestedBean[] { new NestedBean() };
+    private List<NestedBean> i = Collections.singletonList(new NestedBean());
+    private Map<Integer, NestedBean> j = Collections.singletonMap(1, new NestedBean());
 
     public double getA() {
       return a;
@@ -163,6 +166,18 @@ public NestedBean getG() {
       return g;
     }
 
+    public NestedBean[] getH() {
+      return h;
+    }
+
+    public List<NestedBean> getI() {
+      return i;
+    }
+
+    public Map<Integer, NestedBean> getJ() {
+      return j;
+    }
+
     public static class NestedBean implements Serializable {
       private int a = 1;
 
@@ -196,7 +211,18 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
       schema.apply("f"));
     Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()),
       schema.apply("g"));
-    Row first = df.select("a", "b", "c", "d", "e", "f", "g").first();
+    ArrayType nestedBeanTypeList = new ArrayType(nestedBeanType, true);
+    Assert.assertEquals(
+      new StructField("h", nestedBeanTypeList, true, Metadata.empty()),
+      schema.apply("h"));
+    Assert.assertEquals(
+      new StructField("i", nestedBeanTypeList, true, Metadata.empty()),
+      schema.apply("i"));
+    Assert.assertEquals(
+      new StructField("j", new MapType(IntegerType$.MODULE$, nestedBeanType, true),
+        true, Metadata.empty()),
+      schema.apply("j"));
+    Row first = df.select("a", "b", "c", "d", "e", "f", "g", "h", "i", "j").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -220,6 +246,21 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     Row nested = first.getStruct(5);
     Assert.assertEquals(bean.getF().getA(), nested.getInt(0));
     Assert.assertTrue(first.isNullAt(6));
+    List<Row> nestedList = first.getList(7);
+    Assert.assertEquals(bean.getH().length, nestedList.size());
+    for (int i = 0; i < bean.getH().length; ++i) {
+      Assert.assertEquals(bean.getH()[i].getA(), nestedList.get(i).getInt(0));
+    }
+    nestedList = first.getList(8);
+    Assert.assertEquals(bean.getI().size(), nestedList.size());
+    for (int i = 0; i < bean.getI().size(); ++i) {
+      Assert.assertEquals(bean.getI().get(i).getA(), nestedList.get(i).getInt(0));
+    }
+    Map<Integer, Row> nestedMap = first.getJavaMap(9);
+    Assert.assertEquals(bean.getJ().size(), nestedMap.size());
+    for (Map.Entry<Integer, Bean.NestedBean> entry : bean.getJ().entrySet()) {
+      Assert.assertEquals(entry.getValue().getA(), nestedMap.get(entry.getKey()).getInt(0));
+    }
   }
 
   @Test