Skip to content

Commit 3535467

Browse files
adrian-wang and marmbrus
authored and committed
[SPARK-4003] [SQL] add 3 types for java SQL context
In JavaSqlContext, we need to let java program use big decimal, timestamp, date types. Author: Daoyuan Wang <[email protected]> Closes #2850 from adrian-wang/javacontext and squashes the following commits: 4c4292c [Daoyuan Wang] change underlying type of JavaSchemaRDD as scala bb0508f [Daoyuan Wang] add test cases 3c58b0d [Daoyuan Wang] add 3 types for java SQL context
1 parent dff0155 commit 3535467

File tree

3 files changed

+59
-5
lines changed

3 files changed

+59
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
2424
import org.apache.spark.annotation.{DeveloperApi, Experimental}
2525
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
2626
import org.apache.spark.sql.json.JsonRDD
27+
import org.apache.spark.sql.types.util.DataTypeConversions
2728
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
2829
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
2930
import org.apache.spark.sql.parquet.ParquetRelation
@@ -97,7 +98,9 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
9798
localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
9899

99100
iter.map { row =>
100-
new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
101+
new GenericRow(
102+
extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
103+
): ScalaRow
101104
}
102105
}
103106
new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
@@ -226,6 +229,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
226229
(org.apache.spark.sql.FloatType, true)
227230
case c: Class[_] if c == classOf[java.lang.Boolean] =>
228231
(org.apache.spark.sql.BooleanType, true)
232+
case c: Class[_] if c == classOf[java.math.BigDecimal] =>
233+
(org.apache.spark.sql.DecimalType, true)
234+
case c: Class[_] if c == classOf[java.sql.Date] =>
235+
(org.apache.spark.sql.DateType, true)
236+
case c: Class[_] if c == classOf[java.sql.Timestamp] =>
237+
(org.apache.spark.sql.TimestampType, true)
229238
}
230239
AttributeReference(property.getName, dataType, nullable)()
231240
}

sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,4 +110,16 @@ protected[sql] object DataTypeConversions {
110110
case structType: org.apache.spark.sql.api.java.StructType =>
111111
StructType(structType.getFields.map(asScalaStructField))
112112
}
113+
114+
/** Converts Java objects to catalyst rows / types */
115+
def convertJavaToCatalyst(a: Any): Any = a match {
116+
case d: java.math.BigDecimal => BigDecimal(d)
117+
case other => other
118+
}
119+
120+
/** Converts catalyst rows / types to Java objects */
121+
def convertCatalystToJava(a: Any): Any = a match {
122+
case d: scala.math.BigDecimal => d.underlying()
123+
case other => other
124+
}
113125
}

sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala

Lines changed: 37 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,9 @@ class AllTypesBean extends Serializable {
4545
@BeanProperty var shortField: java.lang.Short = _
4646
@BeanProperty var byteField: java.lang.Byte = _
4747
@BeanProperty var booleanField: java.lang.Boolean = _
48+
@BeanProperty var dateField: java.sql.Date = _
49+
@BeanProperty var timestampField: java.sql.Timestamp = _
50+
@BeanProperty var bigDecimalField: java.math.BigDecimal = _
4851
}
4952

5053
class JavaSQLSuite extends FunSuite {
@@ -73,6 +76,9 @@ class JavaSQLSuite extends FunSuite {
7376
bean.setShortField(0.toShort)
7477
bean.setByteField(0.toByte)
7578
bean.setBooleanField(false)
79+
bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
80+
bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
81+
bean.setBigDecimalField(new java.math.BigDecimal(0))
7682

7783
val rdd = javaCtx.parallelize(bean :: Nil)
7884
val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -82,10 +88,34 @@ class JavaSQLSuite extends FunSuite {
8288
javaSqlCtx.sql(
8389
"""
8490
|SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
85-
| booleanField
91+
| booleanField, dateField, timestampField, bigDecimalField
8692
|FROM allTypes
8793
""".stripMargin).collect.head.row ===
88-
Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
94+
Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false, java.sql.Date.valueOf("2014-10-10"),
95+
java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), scala.math.BigDecimal(0)))
96+
}
97+
98+
test("decimal types in JavaBeans") {
99+
val bean = new AllTypesBean
100+
bean.setStringField("")
101+
bean.setIntField(0)
102+
bean.setLongField(0)
103+
bean.setFloatField(0.0F)
104+
bean.setDoubleField(0.0)
105+
bean.setShortField(0.toShort)
106+
bean.setByteField(0.toByte)
107+
bean.setBooleanField(false)
108+
bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
109+
bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
110+
bean.setBigDecimalField(new java.math.BigDecimal(0))
111+
112+
val rdd = javaCtx.parallelize(bean :: Nil)
113+
val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
114+
schemaRDD.registerTempTable("decimalTypes")
115+
116+
assert(javaSqlCtx.sql(
117+
"select bigDecimalField + bigDecimalField from decimalTypes"
118+
).collect.head.row === Seq(scala.math.BigDecimal(0)))
89119
}
90120

91121
test("all types null in JavaBeans") {
@@ -98,6 +128,9 @@ class JavaSQLSuite extends FunSuite {
98128
bean.setShortField(null)
99129
bean.setByteField(null)
100130
bean.setBooleanField(null)
131+
bean.setDateField(null)
132+
bean.setTimestampField(null)
133+
bean.setBigDecimalField(null)
101134

102135
val rdd = javaCtx.parallelize(bean :: Nil)
103136
val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -107,10 +140,10 @@ class JavaSQLSuite extends FunSuite {
107140
javaSqlCtx.sql(
108141
"""
109142
|SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
110-
| booleanField
143+
| booleanField, dateField, timestampField, bigDecimalField
111144
|FROM allTypes
112145
""".stripMargin).collect.head.row ===
113-
Seq.fill(8)(null))
146+
Seq.fill(11)(null))
114147
}
115148

116149
test("loads JSON datasets") {

0 commit comments

Comments
 (0)