Commit 11a5a9b

Punya Biswal authored and nemccarthy committed
[SPARK-6996][SQL] Support map types in java beans
liancheng mengxr this is similar to apache#5146.

Author: Punya Biswal <[email protected]>

Closes apache#5578 from punya/feature/SPARK-6996 and squashes the following commits:

d56c3e0 [Punya Biswal] Fix imports
c7e308b [Punya Biswal] Support java iterable types in POJOs
5e00685 [Punya Biswal] Support map types in java beans
1 parent 1c86ed5 commit 11a5a9b
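For context, a minimal sketch of the usage this commit enables: createDataFrame can now infer a schema from a Java bean whose properties are java.util.Map or other java.lang.Iterable types. The Record bean, its field values, and the jsc/sqlContext names below are illustrative only, not part of this commit; the pattern mirrors the Bean class added to JavaDataFrameSuite further down.

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import com.google.common.collect.ImmutableMap;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class MapBeanExample {
  // Hypothetical bean with a Map-typed and an Iterable-typed property,
  // the two shapes this commit teaches schema inference to handle.
  public static class Record implements Serializable {
    private Map<String, int[]> scores = ImmutableMap.of("hello", new int[] { 1, 2 });
    private List<String> tags = Arrays.asList("floppy", "disk");

    public Map<String, int[]> getScores() { return scores; }
    public List<String> getTags() { return tags; }
  }

  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local", "MapBeanExample");
    SQLContext sqlContext = new SQLContext(jsc);

    // Schema is inferred from the bean's getters:
    // scores -> map<string, array<int>>, tags -> array<string>
    DataFrame df = sqlContext.createDataFrame(
      jsc.parallelize(Arrays.asList(new Record())), Record.class);
    df.printSchema();

    jsc.stop();
  }
}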

4 files changed: +180 additions, −59 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 20 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
+import java.lang.{Iterable => JavaIterable}
 import java.util.{Map => JavaMap}
 
 import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
     case (s: Seq[_], arrayType: ArrayType) =>
       s.map(convertToCatalyst(_, arrayType.elementType))
 
+    case (jit: JavaIterable[_], arrayType: ArrayType) => {
+      val iter = jit.iterator
+      var listOfItems: List[Any] = List()
+      while (iter.hasNext) {
+        val item = iter.next()
+        listOfItems :+= convertToCatalyst(item, arrayType.elementType)
+      }
+      listOfItems
+    }
+
     case (s: Array[_], arrayType: ArrayType) =>
       s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
 
@@ -124,6 +135,15 @@ object CatalystTypeConverters {
       extractOption(item) match {
         case a: Array[_] => a.toSeq.map(elementConverter)
         case s: Seq[_] => s.map(elementConverter)
+        case i: JavaIterable[_] => {
+          val iter = i.iterator
+          var convertedIterable: List[Any] = List()
+          while (iter.hasNext) {
+            val item = iter.next()
+            convertedIterable :+= elementConverter(item)
+          }
+          convertedIterable
+        }
         case null => null
       }
     }
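The practical effect of the new JavaIterable branches above: when a bean row is converted to Catalyst, Java collections become Scala collections, so callers read them back as Scala types. A rough Java sketch of that read-back path (it mirrors the assertions added to JavaDataFrameSuite below; the DataFrame df and its column names are assumed for illustration):

import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;

import scala.collection.Seq;

public class ReadBackSketch {
  // df is assumed to come from a bean with a Map<String, int[]> property "scores"
  // and a List<String> property "tags".
  static void readBack(DataFrame df) {
    Row first = df.select("scores", "tags").first();
    // The Java map was converted to a Scala map; getJavaMap exposes it as a java.util.Map again.
    java.util.Map<String, Object> scores = first.getJavaMap(0);
    // The Java List (a java.lang.Iterable) was converted to a Scala Seq by the branch above.
    Seq<String> tags = first.getAs(1);
    System.out.println(scores + " / " + tags);
  }
}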

sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala (new file)

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.beans.Introspector
+import java.lang.{Iterable => JIterable}
+import java.util.{Iterator => JIterator, Map => JMap}
+
+import com.google.common.reflect.TypeToken
+
+import org.apache.spark.sql.types._
+
+import scala.language.existentials
+
+/**
+ * Type-inference utilities for POJOs and Java collections.
+ */
+private [sql] object JavaTypeInference {
+
+  private val iterableType = TypeToken.of(classOf[JIterable[_]])
+  private val mapType = TypeToken.of(classOf[JMap[_, _]])
+  private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
+  private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
+  private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
+  private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
+
+  /**
+   * Infers the corresponding SQL data type of a Java type.
+   * @param typeToken Java type
+   * @return (SQL data type, nullable)
+   */
+  private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
+    typeToken.getRawType match {
+      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+      case _ if typeToken.isArray =>
+        val (dataType, nullable) = inferDataType(typeToken.getComponentType)
+        (ArrayType(dataType, nullable), true)
+
+      case _ if iterableType.isAssignableFrom(typeToken) =>
+        val (dataType, nullable) = inferDataType(elementType(typeToken))
+        (ArrayType(dataType, nullable), true)
+
+      case _ if mapType.isAssignableFrom(typeToken) =>
+        val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
+        val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
+        val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
+        val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
+        val (keyDataType, _) = inferDataType(keyType)
+        val (valueDataType, nullable) = inferDataType(valueType)
+        (MapType(keyDataType, valueDataType, nullable), true)
+
+      case _ =>
+        val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
+        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+        val fields = properties.map { property =>
+          val returnType = typeToken.method(property.getReadMethod).getReturnType
+          val (dataType, nullable) = inferDataType(returnType)
+          new StructField(property.getName, dataType, nullable)
+        }
+        (new StructType(fields), true)
+    }
+  }
+
+  private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
+    val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
+    val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
+    val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
+    val itemType = iteratorType.resolveType(nextReturnType)
+    itemType
+  }
+}
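For reference, the schema the new inference should produce for the test Bean below (double a, Integer[] b, Map<String, int[]> c, List<String> d), spelled out with the public DataTypes factory methods rather than by calling the private[sql] JavaTypeInference directly. This is a sketch that mirrors the assertions in the updated JavaDataFrameSuite; the class name is illustrative:

import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class ExpectedBeanSchema {
  public static StructType expected() {
    // int[] holds primitive ints, so the array element type is non-nullable.
    ArrayType cValueType = DataTypes.createArrayType(DataTypes.IntegerType, false);
    MapType cType = DataTypes.createMapType(DataTypes.StringType, cValueType, true);
    return DataTypes.createStructType(new StructField[] {
      // primitive double -> non-nullable DoubleType
      DataTypes.createStructField("a", DataTypes.DoubleType, false),
      // Integer[] -> array<int> with nullable elements
      DataTypes.createStructField("b", DataTypes.createArrayType(DataTypes.IntegerType, true), true),
      // Map<String, int[]> -> map<string, array<int>>
      DataTypes.createStructField("c", cType, true),
      // List<String> (a java.lang.Iterable) -> array<string>
      DataTypes.createStructField("d", DataTypes.createArrayType(DataTypes.StringType, true), true)
    });
  }
}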

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 5 additions & 47 deletions
@@ -25,6 +25,8 @@ import scala.collection.immutable
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
+import com.google.common.reflect.TypeToken
+
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Returns a Catalyst Schema for the given java bean class.
    */
   protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
-    val (dataType, _) = inferDataType(beanClass)
+    val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
     dataType.asInstanceOf[StructType].fields.map { f =>
       AttributeReference(f.name, f.dataType, f.nullable)()
     }
   }
 
-  /**
-   * Infers the corresponding SQL data type of a Java class.
-   * @param clazz Java class
-   * @return (SQL data type, nullable)
-   */
-  private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
-    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
-    clazz match {
-      case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
-        (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
-
-      case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
-      case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
-      case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
-      case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
-      case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
-      case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
-      case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
-      case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
-      case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
-      case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
-      case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
-      case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
-      case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
-      case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
-      case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-
-      case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
-      case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
-      case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
-
-      case c: Class[_] if c.isArray =>
-        val (dataType, nullable) = inferDataType(c.getComponentType)
-        (ArrayType(dataType, nullable), true)
-
-      case _ =>
-        val beanInfo = Introspector.getBeanInfo(clazz)
-        val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
-        val fields = properties.map { property =>
-          val (dataType, nullable) = inferDataType(property.getPropertyType)
-          new StructField(property.getName, dataType, nullable)
-        }
-        (new StructType(fields), true)
-    }
-  }
 }
+
+

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 45 additions & 12 deletions
@@ -17,23 +17,28 @@
 
 package test.org.apache.spark.sql;
 
-import java.io.Serializable;
-import java.util.Arrays;
-
-import scala.collection.Seq;
-
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Ignore;
-import org.junit.Test;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.primitives.Ints;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.TestData$;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.*;
+import org.junit.*;
+
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+import scala.collection.mutable.Buffer;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
 
 import static org.apache.spark.sql.functions.*;
 
@@ -106,6 +111,8 @@ public void testShow() {
   public static class Bean implements Serializable {
     private double a = 0.0;
     private Integer[] b = new Integer[]{0, 1};
+    private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
+    private List<String> d = Arrays.asList("floppy", "disk");
 
     public double getA() {
       return a;
@@ -114,6 +121,14 @@ public double getA() {
     public Integer[] getB() {
       return b;
     }
+
+    public Map<String, int[]> getC() {
+      return c;
+    }
+
+    public List<String> getD() {
+      return d;
+    }
   }
 
   @Test
@@ -127,7 +142,15 @@ public void testCreateDataFrameFromJavaBeans() {
     Assert.assertEquals(
       new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
       schema.apply("b"));
-    Row first = df.select("a", "b").first();
+    ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
+    MapType mapType = new MapType(DataTypes.StringType, valueType, true);
+    Assert.assertEquals(
+      new StructField("c", mapType, true, Metadata.empty()),
+      schema.apply("c"));
+    Assert.assertEquals(
+      new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
+      schema.apply("d"));
+    Row first = df.select("a", "b", "c", "d").first();
     Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
     // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below,
     // verify that it has the expected length, and contains expected elements.
@@ -136,5 +159,15 @@ public void testCreateDataFrameFromJavaBeans() {
     for (int i = 0; i < result.length(); i++) {
       Assert.assertEquals(bean.getB()[i], result.apply(i));
     }
+    Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+    Assert.assertArrayEquals(
+      bean.getC().get("hello"),
+      Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
+    Seq<String> d = first.getAs(3);
+    Assert.assertEquals(bean.getD().size(), d.length());
+    for (int i = 0; i < d.length(); i++) {
+      Assert.assertEquals(bean.getD().get(i), d.apply(i));
+    }
   }
+
 }
