diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index af6018472cb0..dfb12f272eb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1098,16 +1098,29 @@ object SQLContext {
       data: Iterator[_],
       beanClass: Class[_],
       attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
-    val extractors =
-      JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
-    val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
-      (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
+    def createStructConverter(cls: Class[_], fieldTypes: Seq[DataType]): Any => InternalRow = {
+      val methodConverters =
+        JavaTypeInference.getJavaBeanReadableProperties(cls).zip(fieldTypes)
+          .map { case (property, fieldType) =>
+            val method = property.getReadMethod
+            method -> createConverter(method.getReturnType, fieldType)
+          }
+      value =>
+        if (value == null) {
+          null
+        } else {
+          new GenericInternalRow(
+            methodConverters.map { case (method, converter) =>
+              converter(method.invoke(value))
+            })
+        }
     }
-    data.map { element =>
-      new GenericInternalRow(
-        methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
-      ): InternalRow
+    def createConverter(cls: Class[_], dataType: DataType): Any => Any = dataType match {
+      case struct: StructType => createStructConverter(cls, struct.map(_.dataType))
+      case _ => CatalystTypeConverters.createToCatalystConverter(dataType)
     }
+    val dataConverter = createStructConverter(beanClass, attrs.map(_.dataType))
+    data.map(dataConverter)
   }
 
   /**
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 3f37e5814cca..df8613f73300 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -134,6 +134,8 @@ public static class Bean implements Serializable {
     private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
     private List<String> d = Arrays.asList("floppy", "disk");
     private BigInteger e = new BigInteger("1234567");
+    private NestedBean f = new NestedBean();
+    private NestedBean g = null;
 
     public double getA() {
       return a;
@@ -152,6 +154,22 @@ public List<String> getD() {
     }
 
     public BigInteger getE() { return e; }
+
+    public NestedBean getF() {
+      return f;
+    }
+
+    public NestedBean getG() {
+      return g;
+    }
+
+    public static class NestedBean implements Serializable {
+      private int a = 1;
+
+      public int getA() {
+        return a;
+      }
+    }
   }
 
   void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
@@ -171,7 +189,14 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
       schema.apply("d"));
     Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()),
       schema.apply("e"));
-    Row first = df.select("a", "b", "c", "d", "e").first();
+    StructType nestedBeanType =
+      DataTypes.createStructType(Collections.singletonList(new StructField(
+        "a", IntegerType$.MODULE$, false, Metadata.empty())));
+    Assert.assertEquals(new StructField("f", nestedBeanType, true, Metadata.empty()),
+      schema.apply("f"));
+    Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()),
+      schema.apply("g"));
+    Row first = df.select("a", "b", "c", "d", "e", "f", "g").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -192,6 +217,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     }
     // Java.math.BigInteger is equivalent to Spark Decimal(38,0)
     Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
+    Row nested = first.getStruct(5);
+    Assert.assertEquals(bean.getF().getA(), nested.getInt(0));
+    Assert.assertTrue(first.isNullAt(6));
   }
 
   @Test
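
Note: the snippet below is an illustrative sketch, not part of the patch. It shows the user-facing effect of the recursive converter through the public SparkSession.createDataFrame(List<?>, Class<?>) API: a bean property whose type is itself a bean maps to a nested struct column, and a null nested bean becomes a null struct, mirroring what the updated test asserts. The class and property names (NestedBeanExample, Record, Point, getPoint, getMissing) are made up for this example.

    import java.io.Serializable;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class NestedBeanExample {
      public static class Point implements Serializable {
        private int x = 1;
        private int y = 2;
        public int getX() { return x; }
        public int getY() { return y; }
      }

      public static class Record implements Serializable {
        private String name = "r1";
        private Point point = new Point();   // nested bean, expected to map to struct<x:int,y:int>
        private Point missing = null;        // null nested bean, expected to map to a null struct
        public String getName() { return name; }
        public Point getPoint() { return point; }
        public Point getMissing() { return missing; }
      }

      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
          .master("local[*]").appName("nested-bean-example").getOrCreate();

        List<Record> data = Arrays.asList(new Record(), new Record());
        Dataset<Row> df = spark.createDataFrame(data, Record.class);
        df.printSchema();  // name as string, point/missing as nested structs

        // Read the nested struct back, analogous to first.getStruct(5) in the test above.
        Row first = df.first();
        Row point = first.getStruct(df.schema().fieldIndex("point"));
        System.out.println(point.getInt(point.fieldIndex("x")));

        spark.stop();
      }
    }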