sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (29 changes: 21 additions & 8 deletions)
@@ -1098,16 +1098,29 @@ object SQLContext {
       data: Iterator[_],
       beanClass: Class[_],
       attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
-    val extractors =
-      JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
-    val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
-      (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
+    def createStructConverter(cls: Class[_], fieldTypes: Seq[DataType]): Any => InternalRow = {
+      val methodConverters =
+        JavaTypeInference.getJavaBeanReadableProperties(cls).zip(fieldTypes)
+          .map { case (property, fieldType) =>
+            val method = property.getReadMethod
+            method -> createConverter(method.getReturnType, fieldType)
+          }
+      value =>
+        if (value == null) {
+          null
+        } else {
+          new GenericInternalRow(
+            methodConverters.map { case (method, converter) =>
+              converter(method.invoke(value))
+            })
+        }
     }
-    data.map { element =>
-      new GenericInternalRow(
-        methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
-      ): InternalRow
+    def createConverter(cls: Class[_], dataType: DataType): Any => Any = dataType match {
+      case struct: StructType => createStructConverter(cls, struct.map(_.dataType))
+      case _ => CatalystTypeConverters.createToCatalystConverter(dataType)
     }
+    val dataConverter = createStructConverter(beanClass, attrs.map(_.dataType))
+    data.map(dataConverter)
   }
 
   /**
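The core of the patch: instead of building one flat list of (getter, Catalyst converter) pairs, `beansToRows` now defines two mutually recursive helpers. `createConverter` dispatches on the Catalyst `DataType`: for a `StructType` it recurses through `createStructConverter`, treating the getter's return type as another bean; for anything else it falls back to `CatalystTypeConverters`. `createStructConverter` also maps a `null` bean to a `null` struct, which the flat version could not express for nested values. The Java test hunks below exercise exactly this pair of cases: a populated nested bean and a `null` one.

As a rough, self-contained sketch of the same recursion pattern (not Spark's code: it uses `java.beans.Introspector` directly instead of `JavaTypeInference`, plain `Seq[Any]` instead of `GenericInternalRow`, and a crude "is this a nested bean?" test instead of the Catalyst schema; all names below are made up for illustration):

```scala
import java.beans.Introspector
import java.lang.reflect.Method

// Made-up example beans: plain Scala classes with Java-style getters.
class Inner { def getX: Int = 42 }
class Outer {
  def getName: String = "outer"
  def getInner: Inner = new Inner
  def getMissing: Inner = null  // a null nested bean should stay null
}

object NestedBeanSketch {

  // Readable properties of a bean class, excluding the synthetic "class" property.
  private def readMethods(cls: Class[_]): Seq[Method] =
    Introspector.getBeanInfo(cls).getPropertyDescriptors.toSeq
      .filter(_.getName != "class")
      .flatMap(p => Option(p.getReadMethod))

  // Crude stand-in for "does this property map to a StructType?":
  // anything that is not a primitive, String, or Number is treated as a nested bean.
  private def isNestedBean(cls: Class[_]): Boolean =
    !cls.isPrimitive &&
      !classOf[String].isAssignableFrom(cls) &&
      !classOf[java.lang.Number].isAssignableFrom(cls)

  // A bean becomes a Seq[Any]; a null bean becomes null.
  def createStructConverter(cls: Class[_]): Any => Seq[Any] = {
    val methodConverters = readMethods(cls).map { method =>
      method -> createConverter(method.getReturnType)
    }
    value =>
      if (value == null) null
      else methodConverters.map { case (method, converter) => converter(method.invoke(value)) }
  }

  // Recurse into nested beans, pass atomic values through unchanged.
  def createConverter(cls: Class[_]): Any => Any =
    if (isNestedBean(cls)) createStructConverter(cls) else identity

  def main(args: Array[String]): Unit = {
    println(createStructConverter(classOf[Outer])(new Outer))
    // e.g. List(List(42), null, outer) -- property order follows Introspector's ordering
  }
}
```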
@@ -134,6 +134,8 @@ public static class Bean implements Serializable {
     private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
     private List<String> d = Arrays.asList("floppy", "disk");
     private BigInteger e = new BigInteger("1234567");
+    private NestedBean f = new NestedBean();
+    private NestedBean g = null;
 
     public double getA() {
       return a;
@@ -152,6 +154,22 @@ public List<String> getD() {
     }
 
     public BigInteger getE() { return e; }
+
+    public NestedBean getF() {
+      return f;
+    }
+
+    public NestedBean getG() {
+      return g;
+    }
+
+    public static class NestedBean implements Serializable {
+      private int a = 1;
+
+      public int getA() {
+        return a;
+      }
+    }
   }
 
   void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
Expand All @@ -171,7 +189,14 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
schema.apply("d"));
Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true,
Metadata.empty()), schema.apply("e"));
Row first = df.select("a", "b", "c", "d", "e").first();
StructType nestedBeanType =
DataTypes.createStructType(Collections.singletonList(new StructField(
"a", IntegerType$.MODULE$, false, Metadata.empty())));
Assert.assertEquals(new StructField("f", nestedBeanType, true, Metadata.empty()),
schema.apply("f"));
Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()),
schema.apply("g"));
Row first = df.select("a", "b", "c", "d", "e", "f", "g").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
// verify that it has the expected length, and contains expected elements.
@@ -192,6 +217,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
     }
     // Java.math.BigInteger is equivalent to Spark Decimal(38,0)
     Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
+    Row nested = first.getStruct(5);
+    Assert.assertEquals(bean.getF().getA(), nested.getInt(0));
+    Assert.assertTrue(first.isNullAt(6));
   }
 
   @Test
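The schema assertions for `f` and `g` pin down the contract: a nested bean maps to a nullable struct column (`nestedBeanType`), the bean's primitive `int` field stays non-nullable inside that struct, and a `null` nested bean surfaces as a null struct in the resulting `Row`. As a hypothetical end-to-end sketch of what a caller sees once this path works (Scala classes `Person`/`Address` with Java-style getters standing in for Java beans; not part of the patch, and the printed field order and nullability below are my expectation, not something asserted by the tests):

```scala
import org.apache.spark.sql.SparkSession

// Made-up beans: Spark's bean introspection only needs the get* methods.
class Address extends Serializable {
  def getCity: String = "Berlin"
}

class Person extends Serializable {
  def getName: String = "Ada"
  def getAddress: Address = new Address
  def getBackupAddress: Address = null  // exercises the null-struct path
}

object NestedBeanUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("nested-bean-demo")
      .getOrCreate()

    // Bean-based createDataFrame; with this patch, nested beans become struct columns.
    val df = spark.createDataFrame(java.util.Arrays.asList(new Person), classOf[Person])

    df.printSchema()
    // Expected shape (roughly; properties are introspected, so order may differ):
    // root
    //  |-- address: struct (nullable = true)
    //  |    |-- city: string (nullable = true)
    //  |-- backupAddress: struct (nullable = true)
    //  |    |-- city: string (nullable = true)
    //  |-- name: string (nullable = true)

    df.show()
    spark.stop()
  }
}
```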