@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.util.DataTypeConversions
 import org.apache.spark.sql.{SQLContext, StructType => SStructType}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
 import org.apache.spark.sql.parquet.ParquetRelation
@@ -97,7 +98,9 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
         localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
 
       iter.map { row =>
-        new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
+        new GenericRow(
+          extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
+        ): ScalaRow
       }
     }
     new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
@@ -226,6 +229,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
           (org.apache.spark.sql.FloatType, true)
         case c: Class[_] if c == classOf[java.lang.Boolean] =>
           (org.apache.spark.sql.BooleanType, true)
+        case c: Class[_] if c == classOf[java.math.BigDecimal] =>
+          (org.apache.spark.sql.DecimalType, true)
+        case c: Class[_] if c == classOf[java.sql.Date] =>
+          (org.apache.spark.sql.DateType, true)
+        case c: Class[_] if c == classOf[java.sql.Timestamp] =>
+          (org.apache.spark.sql.TimestampType, true)
       }
       AttributeReference(property.getName, dataType, nullable)()
     }
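
For a sense of what the extended match above does, here is a minimal standalone sketch (hypothetical PriceBean class and typeNameFor helper, not part of this patch; plain strings stand in for Spark's Catalyst type objects) of the same getter-class-to-type mapping over a bean's properties:

import java.beans.Introspector
import scala.beans.BeanProperty

// Hypothetical bean carrying the three newly supported field classes.
class PriceBean extends Serializable {
  @BeanProperty var price: java.math.BigDecimal = _
  @BeanProperty var day: java.sql.Date = _
  @BeanProperty var ts: java.sql.Timestamp = _
}

object BeanSchemaSketch {
  // Same shape as the match in applySchema, but returning type names as strings.
  def typeNameFor(c: Class[_]): String = c match {
    case _ if c == classOf[java.math.BigDecimal] => "DecimalType"
    case _ if c == classOf[java.sql.Date] => "DateType"
    case _ if c == classOf[java.sql.Timestamp] => "TimestampType"
    case _ => s"unsupported: ${c.getName}"
  }

  def main(args: Array[String]): Unit = {
    // Discover getters the same way applySchema does, then map their classes.
    Introspector.getBeanInfo(classOf[PriceBean]).getPropertyDescriptors
      .filterNot(_.getName == "class")
      .foreach(p => println(s"${p.getName}: ${typeNameFor(p.getPropertyType)}"))
    // Prints (Introspector orders properties alphabetically):
    //   day: DateType, price: DecimalType, ts: TimestampType
  }
}
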
@@ -110,4 +110,16 @@ protected[sql] object DataTypeConversions {
     case structType: org.apache.spark.sql.api.java.StructType =>
       StructType(structType.getFields.map(asScalaStructField))
   }
+
+  /** Converts Java objects to Catalyst rows / types. */
+  def convertJavaToCatalyst(a: Any): Any = a match {
+    case d: java.math.BigDecimal => BigDecimal(d)
+    case other => other
+  }
+
+  /** Converts Catalyst rows / types back to Java objects. */
+  def convertCatalystToJava(a: Any): Any = a match {
+    case d: scala.math.BigDecimal => d.underlying()
+    case other => other
+  }
 }
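
A quick way to see what the two helpers do is a round trip on a BigDecimal. This is a standalone sketch (the local copies below simply mirror the methods added above, so it runs without Spark on the classpath):

object ConversionRoundTrip {
  // Local mirrors of the helpers above.
  def convertJavaToCatalyst(a: Any): Any = a match {
    case d: java.math.BigDecimal => BigDecimal(d) // wrap as scala.math.BigDecimal
    case other => other
  }

  def convertCatalystToJava(a: Any): Any = a match {
    case d: scala.math.BigDecimal => d.underlying() // unwrap to java.math.BigDecimal
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val javaDec = new java.math.BigDecimal("12.34")
    val catalyst = convertJavaToCatalyst(javaDec)      // scala.math.BigDecimal(12.34)
    val roundTripped = convertCatalystToJava(catalyst) // back to java.math.BigDecimal
    assert(roundTripped == javaDec)
    // Anything that is not a decimal passes through untouched.
    assert(convertJavaToCatalyst("unchanged") == "unchanged")
  }
}
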
@@ -45,6 +45,9 @@ class AllTypesBean extends Serializable {
   @BeanProperty var shortField: java.lang.Short = _
   @BeanProperty var byteField: java.lang.Byte = _
   @BeanProperty var booleanField: java.lang.Boolean = _
+  @BeanProperty var dateField: java.sql.Date = _
+  @BeanProperty var timestampField: java.sql.Timestamp = _
+  @BeanProperty var bigDecimalField: java.math.BigDecimal = _
 }
 
 class JavaSQLSuite extends FunSuite {
@@ -73,6 +76,9 @@ class JavaSQLSuite extends FunSuite {
     bean.setShortField(0.toShort)
     bean.setByteField(0.toByte)
     bean.setBooleanField(false)
+    bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
+    bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
+    bean.setBigDecimalField(new java.math.BigDecimal(0))
 
     val rdd = javaCtx.parallelize(bean :: Nil)
     val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -82,10 +88,34 @@
       javaSqlCtx.sql(
         """
           |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
-          |       booleanField
+          |       booleanField, dateField, timestampField, bigDecimalField
           |FROM allTypes
         """.stripMargin).collect.head.row ===
-        Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
+        Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false, java.sql.Date.valueOf("2014-10-10"),
+          java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), scala.math.BigDecimal(0)))
   }
 
+  test("decimal types in JavaBeans") {
+    val bean = new AllTypesBean
+    bean.setStringField("")
+    bean.setIntField(0)
+    bean.setLongField(0)
+    bean.setFloatField(0.0F)
+    bean.setDoubleField(0.0)
+    bean.setShortField(0.toShort)
+    bean.setByteField(0.toByte)
+    bean.setBooleanField(false)
+    bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
+    bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
+    bean.setBigDecimalField(new java.math.BigDecimal(0))
+
+    val rdd = javaCtx.parallelize(bean :: Nil)
+    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+    schemaRDD.registerTempTable("decimalTypes")
+
+    assert(javaSqlCtx.sql(
+      "select bigDecimalField + bigDecimalField from decimalTypes"
+    ).collect.head.row === Seq(scala.math.BigDecimal(0)))
+  }
+
   test("all types null in JavaBeans") {
@@ -98,6 +128,9 @@
     bean.setShortField(null)
     bean.setByteField(null)
     bean.setBooleanField(null)
+    bean.setDateField(null)
+    bean.setTimestampField(null)
+    bean.setBigDecimalField(null)
 
     val rdd = javaCtx.parallelize(bean :: Nil)
     val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -107,10 +140,10 @@
       javaSqlCtx.sql(
         """
           |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
-          |       booleanField
+          |       booleanField, dateField, timestampField, bigDecimalField
           |FROM allTypes
         """.stripMargin).collect.head.row ===
-        Seq.fill(8)(null))
+        Seq.fill(11)(null))
   }
 
   test("loads JSON datasets") {
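
One note on the test fixtures: the dates and timestamps are built with the JDBC escape-format parsers, which is worth knowing when extending these suites. A standalone check (values chosen to match the tests above):

object FixtureFormats {
  def main(args: Array[String]): Unit = {
    // java.sql.Date.valueOf expects "yyyy-[m]m-[d]d".
    val d = java.sql.Date.valueOf("2014-10-10")
    // java.sql.Timestamp.valueOf expects "yyyy-[m]m-[d]d hh:mm:ss[.f...]".
    val ts = java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0")
    println(s"$d / $ts") // 2014-10-10 / 2014-10-10 00:00:00.0
  }
}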