Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ import java.math.BigDecimal
import java.sql.{Connection, Date, Timestamp}
import java.util.Properties

import org.scalatest._
import org.scalatest.Ignore

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType}
import org.apache.spark.tags.DockerTest


@DockerTest
@Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker
class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
Expand All @@ -47,19 +50,22 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate()

conn.prepareStatement("CREATE TABLE numbers ( small SMALLINT, med INTEGER, big BIGINT, "
+ "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE)").executeUpdate()
+ "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE, real REAL, "
+ "decflt DECFLOAT, decflt16 DECFLOAT(16), decflt34 DECFLOAT(34))").executeUpdate()
conn.prepareStatement("INSERT INTO numbers VALUES (17, 77777, 922337203685477580, "
+ "123456745.56789012345000000000, 42.75, 5.4E-70)").executeUpdate()
+ "123456745.56789012345000000000, 42.75, 5.4E-70, "
+ "3.4028234663852886e+38, 4.2999, DECFLOAT('9.999999999999999E19', 16), "
+ "DECFLOAT('1234567891234567.123456789123456789', 34))").executeUpdate()

conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, ts TIMESTAMP )").executeUpdate()
conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
+ "'2009-02-13 23:31:30')").executeUpdate()

// TODO: Test locale conversion for strings.
conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB)")
.executeUpdate()
conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'))")
conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, e XML)")
.executeUpdate()
conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'),"
+ "'<cinfo cid=\"10\"><name>Kathy</name></cinfo>')").executeUpdate()
}

test("Basic test") {
Expand All @@ -77,20 +83,28 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 6)
assert(types.length == 10)
assert(types(0).equals("class java.lang.Integer"))
assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Long"))
assert(types(3).equals("class java.math.BigDecimal"))
assert(types(4).equals("class java.lang.Double"))
assert(types(5).equals("class java.lang.Double"))
assert(types(6).equals("class java.lang.Float"))
assert(types(7).equals("class java.math.BigDecimal"))
assert(types(8).equals("class java.math.BigDecimal"))
assert(types(9).equals("class java.math.BigDecimal"))
assert(rows(0).getInt(0) == 17)
assert(rows(0).getInt(1) == 77777)
assert(rows(0).getLong(2) == 922337203685477580L)
val bd = new BigDecimal("123456745.56789012345000000000")
assert(rows(0).getAs[BigDecimal](3).equals(bd))
assert(rows(0).getDouble(4) == 42.75)
assert(rows(0).getDouble(5) == 5.4E-70)
assert(rows(0).getFloat(6) == 3.4028234663852886e+38)
assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000"))
assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000"))
assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789"))
}

test("Date types") {
Expand All @@ -112,7 +126,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 4)
assert(types.length == 5)
assert(types(0).equals("class java.lang.String"))
assert(types(1).equals("class java.lang.String"))
assert(types(2).equals("class java.lang.String"))
Expand All @@ -121,14 +135,27 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getString(1).equals("quick"))
assert(rows(0).getString(2).equals("brown"))
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](3), Array[Byte](102, 111, 120)))
assert(rows(0).getString(4).equals("""<cinfo cid="10"><name>Kathy</name></cinfo>"""))
}

test("Basic write test") {
// val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
// cast decflt column with precision value of 38 to DB2 max decimal precision value of 31.
val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
.selectExpr("small", "med", "big", "deci", "flt", "dbl", "real",
"cast(decflt as decimal(31, 5)) as decflt")
val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties)
val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties)
// df1.write.jdbc(jdbcUrl, "numberscopy", new Properties)
df1.write.jdbc(jdbcUrl, "numberscopy", new Properties)
df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
// spark types that does not have exact matching db2 table types.
val df4 = sqlContext.createDataFrame(
sparkContext.parallelize(Seq(Row("1".toShort, "20".toByte, true))),
new StructType().add("c1", ShortType).add("b", ByteType).add("c3", BooleanType))
df4.write.jdbc(jdbcUrl, "otherscopy", new Properties)
val rows = sqlContext.read.jdbc(jdbcUrl, "otherscopy", new Properties).collect()
assert(rows(0).getInt(0) == 1)
assert(rows(0).getInt(1) == 20)
assert(rows(0).getString(2) == "1")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,34 @@

package org.apache.spark.sql.jdbc

import org.apache.spark.sql.types.{BooleanType, DataType, StringType}
import java.sql.Types

import org.apache.spark.sql.types._

private object DB2Dialect extends JdbcDialect {

override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")

override def getCatalystType(
sqlType: Int,
typeName: String,
size: Int,
md: MetadataBuilder): Option[DataType] = sqlType match {
case Types.REAL => Option(FloatType)
case Types.OTHER =>
typeName match {
case "DECFLOAT" => Option(DecimalType(38, 18))
case "XML" => Option(StringType)
case t if (t.startsWith("TIMESTAMP")) => Option(TimestampType) // TIMESTAMP WITH TIMEZONE
case _ => None
}
case _ => None
}

override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB))
case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
case ShortType | ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
case _ => None
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,15 @@ class JDBCSuite extends SparkFunSuite
val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)")
assert(db2Dialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == "SMALLINT")
assert(db2Dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT")
// test db2 dialect mappings on read
assert(db2Dialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, null) == Option(FloatType))
assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "DECFLOAT", 1, null) ==
Option(DecimalType(38, 18)))
assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "XML", 1, null) == Option(StringType))
assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "TIMESTAMP WITH TIME ZONE", 1, null) ==
Option(TimestampType))
}

test("PostgresDialect type mapping") {
Expand Down