diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
index 3da34b1b382d..f5930bc281e8 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
@@ -21,10 +21,13 @@ import java.math.BigDecimal
import java.sql.{Connection, Date, Timestamp}
import java.util.Properties
-import org.scalatest._
+import org.scalatest.Ignore
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType}
import org.apache.spark.tags.DockerTest
+
@DockerTest
@Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker
class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
@@ -47,19 +50,22 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate()
conn.prepareStatement("CREATE TABLE numbers ( small SMALLINT, med INTEGER, big BIGINT, "
- + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE)").executeUpdate()
+ + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE, real REAL, "
+ + "decflt DECFLOAT, decflt16 DECFLOAT(16), decflt34 DECFLOAT(34))").executeUpdate()
conn.prepareStatement("INSERT INTO numbers VALUES (17, 77777, 922337203685477580, "
- + "123456745.56789012345000000000, 42.75, 5.4E-70)").executeUpdate()
+ + "123456745.56789012345000000000, 42.75, 5.4E-70, "
+ + "3.4028234663852886e+38, 4.2999, DECFLOAT('9.999999999999999E19', 16), "
+ + "DECFLOAT('1234567891234567.123456789123456789', 34))").executeUpdate()
conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, ts TIMESTAMP )").executeUpdate()
conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
+ "'2009-02-13 23:31:30')").executeUpdate()
// TODO: Test locale conversion for strings.
- conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'))")
+ conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, e XML)")
.executeUpdate()
+ conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'),"
+ + "'Kathy')").executeUpdate()
}
test("Basic test") {
@@ -77,13 +83,17 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
- assert(types.length == 6)
+ assert(types.length == 10)
assert(types(0).equals("class java.lang.Integer"))
assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Long"))
assert(types(3).equals("class java.math.BigDecimal"))
assert(types(4).equals("class java.lang.Double"))
assert(types(5).equals("class java.lang.Double"))
+ assert(types(6).equals("class java.lang.Float"))
+ assert(types(7).equals("class java.math.BigDecimal"))
+ assert(types(8).equals("class java.math.BigDecimal"))
+ assert(types(9).equals("class java.math.BigDecimal"))
assert(rows(0).getInt(0) == 17)
assert(rows(0).getInt(1) == 77777)
assert(rows(0).getLong(2) == 922337203685477580L)
@@ -91,6 +101,10 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getAs[BigDecimal](3).equals(bd))
assert(rows(0).getDouble(4) == 42.75)
assert(rows(0).getDouble(5) == 5.4E-70)
+ assert(rows(0).getFloat(6) == 3.4028234663852886e+38)
+ assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000"))
+ assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000"))
+ assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789"))
}
test("Date types") {
@@ -112,7 +126,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
- assert(types.length == 4)
+ assert(types.length == 5)
assert(types(0).equals("class java.lang.String"))
assert(types(1).equals("class java.lang.String"))
assert(types(2).equals("class java.lang.String"))
@@ -121,14 +135,27 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getString(1).equals("quick"))
assert(rows(0).getString(2).equals("brown"))
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](3), Array[Byte](102, 111, 120)))
+ assert(rows(0).getString(4).equals("""Kathy"""))
}
test("Basic write test") {
- // val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
+ // cast decflt column with precision value of 38 to DB2 max decimal precision value of 31.
+ val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
+ .selectExpr("small", "med", "big", "deci", "flt", "dbl", "real",
+ "cast(decflt as decimal(31, 5)) as decflt")
val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties)
val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties)
- // df1.write.jdbc(jdbcUrl, "numberscopy", new Properties)
+ df1.write.jdbc(jdbcUrl, "numberscopy", new Properties)
df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
+ // spark types that does not have exact matching db2 table types.
+ val df4 = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(Row("1".toShort, "20".toByte, true))),
+ new StructType().add("c1", ShortType).add("b", ByteType).add("c3", BooleanType))
+ df4.write.jdbc(jdbcUrl, "otherscopy", new Properties)
+ val rows = sqlContext.read.jdbc(jdbcUrl, "otherscopy", new Properties).collect()
+ assert(rows(0).getInt(0) == 1)
+ assert(rows(0).getInt(1) == 20)
+ assert(rows(0).getString(2) == "1")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 190463df0d92..d160ad82888a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -17,15 +17,34 @@
package org.apache.spark.sql.jdbc
-import org.apache.spark.sql.types.{BooleanType, DataType, StringType}
+import java.sql.Types
+
+import org.apache.spark.sql.types._
private object DB2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")
+ override def getCatalystType(
+ sqlType: Int,
+ typeName: String,
+ size: Int,
+ md: MetadataBuilder): Option[DataType] = sqlType match {
+ case Types.REAL => Option(FloatType)
+ case Types.OTHER =>
+ typeName match {
+ case "DECFLOAT" => Option(DecimalType(38, 18))
+ case "XML" => Option(StringType)
+ case t if (t.startsWith("TIMESTAMP")) => Option(TimestampType) // TIMESTAMP WITH TIMEZONE
+ case _ => None
+ }
+ case _ => None
+ }
+
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB))
case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
+ case ShortType | ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
case _ => None
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 70bee929b31d..d1daf860fdff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -713,6 +713,15 @@ class JDBCSuite extends SparkFunSuite
val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)")
+ assert(db2Dialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == "SMALLINT")
+ assert(db2Dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT")
+ // test db2 dialect mappings on read
+ assert(db2Dialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, null) == Option(FloatType))
+ assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "DECFLOAT", 1, null) ==
+ Option(DecimalType(38, 18)))
+ assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "XML", 1, null) == Option(StringType))
+ assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "TIMESTAMP WITH TIME ZONE", 1, null) ==
+ Option(TimestampType))
}
test("PostgresDialect type mapping") {