
Commit 5e00685

Punya Biswal committed
Support map types in java beans
1 parent 327ebf0 commit 5e00685

3 files changed: +132 -48 lines changed

sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import java.beans.Introspector
import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator, Map => JMap}

import scala.language.existentials

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.types._

/**
 * Type-inference utilities for POJOs and Java collections.
 */
private[sql] object JavaTypeInference {

  private val iterableType = TypeToken.of(classOf[JIterable[_]])
  private val mapType = TypeToken.of(classOf[JMap[_, _]])
  private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
  private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
  private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
  private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType

  /**
   * Infers the corresponding SQL data type of a Java type.
   * @param typeToken Java type
   * @return (SQL data type, nullable)
   */
  private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
    typeToken.getRawType match {
      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

      case _ if typeToken.isArray =>
        val (dataType, nullable) = inferDataType(typeToken.getComponentType)
        (ArrayType(dataType, nullable), true)

      case _ if mapType.isAssignableFrom(typeToken) =>
        val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
        val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
        val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
        val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
        val (keyDataType, _) = inferDataType(keyType)
        val (valueDataType, nullable) = inferDataType(valueType)
        (MapType(keyDataType, valueDataType, nullable), true)

      case _ =>
        val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
        val fields = properties.map { property =>
          val returnType = typeToken.method(property.getReadMethod).getReturnType
          val (dataType, nullable) = inferDataType(returnType)
          new StructField(property.getName, dataType, nullable)
        }
        (new StructType(fields), true)
    }
  }

  private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
    val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
    val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
    val itemType = iteratorType.resolveType(nextReturnType)
    itemType
  }
}
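
A rough usage sketch (not part of this commit): the PersonBean class below is hypothetical, and the call site is assumed to sit in the org.apache.spark.sql package because inferDataType is private[sql]. It shows the schema the new map branch should infer for a Map-typed bean property.

package org.apache.spark.sql

import com.google.common.reflect.TypeToken

// Hypothetical Java-style bean; the getters/setters make "name" and "scores" bean properties.
class PersonBean {
  private var name: String = ""
  private var scores: java.util.Map[String, java.lang.Integer] =
    new java.util.HashMap[String, java.lang.Integer]()
  def getName: String = name
  def setName(n: String): Unit = { name = n }
  def getScores: java.util.Map[String, java.lang.Integer] = scores
  def setScores(s: java.util.Map[String, java.lang.Integer]): Unit = { scores = s }
}

object JavaTypeInferenceSketch {
  def main(args: Array[String]): Unit = {
    val (dataType, nullable) = JavaTypeInference.inferDataType(TypeToken.of(classOf[PersonBean]))
    // Expected: a StructType with "name" -> StringType (nullable) and
    // "scores" -> MapType(StringType, IntegerType, valueContainsNull = true); nullable = true.
    println(dataType)
    println(nullable)
  }
}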

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 5 additions & 47 deletions
@@ -25,6 +25,8 @@ import scala.collection.immutable
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
+import com.google.common.reflect.TypeToken
+
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Returns a Catalyst Schema for the given java bean class.
    */
   protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
-    val (dataType, _) = inferDataType(beanClass)
+    val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
     dataType.asInstanceOf[StructType].fields.map { f =>
       AttributeReference(f.name, f.dataType, f.nullable)()
     }
   }
 
-  /**
-   * Infers the corresponding SQL data type of a Java class.
-   * @param clazz Java class
-   * @return (SQL data type, nullable)
-   */
-  private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
-    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
-    clazz match {
-      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
-
-      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
-      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
-      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
-      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
-      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
-      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
-      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
-      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
-      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
-      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
-      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
-      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
-      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
-      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
-      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-
-      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
-      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
-      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
-
-      case c: Class[_] if c.isArray =>
-        val (dataType, nullable) = inferDataType(c.getComponentType)
-        (ArrayType(dataType, nullable), true)
-
-      case _ =>
-        val beanInfo = Introspector.getBeanInfo(clazz)
-        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
-        val fields = properties.map { property =>
-          val (dataType, nullable) = inferDataType(property.getPropertyType)
-          new StructField(property.getName, dataType, nullable)
-        }
-        (new StructType(fields), true)
-    }
-  }
 }
+
+

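Since getSchema now simply delegates to JavaTypeInference, a bean class with a Map-typed getter resolves to a MapType attribute. Below is a minimal sketch of that, assuming a hypothetical TagsBean and a small SQLContext subclass used only to reach the protected getSchema (neither exists in this commit).

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.AttributeReference

// Hypothetical bean whose only property is a java.util.Map.
class TagsBean {
  private var tags: java.util.Map[String, String] = new java.util.HashMap[String, String]()
  def getTags: java.util.Map[String, String] = tags
  def setTags(t: java.util.Map[String, String]): Unit = { tags = t }
}

// getSchema is protected, so expose it through a subclass purely for inspection.
class SchemaPeek(sc: SparkContext) extends SQLContext(sc) {
  def schemaFor(beanClass: Class[_]): Seq[AttributeReference] = getSchema(beanClass)
}

object SchemaPeekSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("schema-peek"))
    val attrs = new SchemaPeek(sc).schemaFor(classOf[TagsBean])
    // Expected: one attribute "tags" of MapType(StringType, StringType, valueContainsNull = true).
    attrs.foreach(a => println(s"${a.name}: ${a.dataType}"))
    sc.stop()
  }
}
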
sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 21 additions & 1 deletion
@@ -19,7 +19,11 @@
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Map;
 
+import com.google.common.collect.ImmutableMap;
+import com.google.common.primitives.Ints;
+import scala.collection.JavaConversions;
 import scala.collection.Seq;
 
 import org.junit.After;
@@ -34,6 +38,7 @@
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.*;
+import scala.collection.mutable.Buffer;
 
 import static org.apache.spark.sql.functions.*;
 
@@ -106,6 +111,7 @@ public void testShow() {
   public static class Bean implements Serializable {
     private double a = 0.0;
     private Integer[] b = new Integer[]{0, 1};
+    private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
 
     public double getA() {
       return a;
@@ -114,6 +120,10 @@ public double getA() {
     public Integer[] getB() {
       return b;
     }
+
+    public Map<String, int[]> getC() {
+      return c;
+    }
   }
 
   @Test
@@ -127,7 +137,12 @@ public void testCreateDataFrameFromJavaBeans() {
     Assert.assertEquals(
       new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
       schema.apply("b"));
-    Row first = df.select("a", "b").first();
+    ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
+    MapType mapType = new MapType(DataTypes.StringType, valueType, true);
+    Assert.assertEquals(
+      new StructField("c", mapType, true, Metadata.empty()),
+      schema.apply("c"));
+    Row first = df.select("a", "b", "c").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converted to Scala Seqs and Maps. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -136,5 +151,10 @@
     for (int i = 0; i < result.length(); i++) {
       Assert.assertEquals(bean.getB()[i], result.apply(i));
     }
+    Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+    Assert.assertArrayEquals(
+      bean.getC().get("hello"),
+      Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
   }
+
 }
