14 changes: 8 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.jdbc

import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException}
import java.util.Properties

import org.apache.commons.lang.StringEscapeUtils.escapeSql
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
@@ -90,9 +91,9 @@ private[sql] object JDBCRDD extends Logging {
* @throws SQLException if the table specification is garbage.
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(url: String, table: String): StructType = {
def resolveTable(url: String, table: String, properties: Properties): StructType = {
val quirks = DriverQuirks.get(url)
val conn: Connection = DriverManager.getConnection(url)
val conn: Connection = DriverManager.getConnection(url, properties)
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
try {
@@ -147,7 +148,7 @@ private[sql] object JDBCRDD extends Logging {
*
* @return A function that loads the driver and connects to the url.
*/
def getConnector(driver: String, url: String): () => Connection = {
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
if (driver != null) Class.forName(driver)
@@ -156,7 +157,7 @@ private[sql] object JDBCRDD extends Logging {
logWarning(s"Couldn't find class $driver", e);
}
}
DriverManager.getConnection(url)
DriverManager.getConnection(url, properties)
}
}
/**
@@ -179,6 +180,7 @@ private[sql] object JDBCRDD extends Logging {
schema: StructType,
driver: String,
url: String,
properties: Properties,
fqTable: String,
requiredColumns: Array[String],
filters: Array[Filter],
@@ -189,7 +191,7 @@ private[sql] object JDBCRDD extends Logging {
return new
JDBCRDD(
sc,
getConnector(driver, url),
getConnector(driver, url, properties),
prunedSchema,
fqTable,
requiredColumns,
@@ -361,7 +363,7 @@ private[sql] class JDBCRDD(
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256*ans + (255 & bytes(j))
ans = 256 * ans + (255 & bytes(j))
j = j + 1;
}
mutableRow.setLong(i, ans)
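Not part of the diff: a minimal standalone sketch of the factory pattern the getConnector change above relies on, with the new Properties argument threaded through to DriverManager.getConnection. Object, value, and credential names (ConnectorSketch, makeConnector, the H2 URL, testUser/testPass) are illustrative assumptions, not code from this PR.

import java.sql.{Connection, DriverManager}
import java.util.Properties

object ConnectorSketch {
  // Mirrors getConnector above: the returned () => Connection opens the connection
  // lazily (on whichever executor invokes it), and the Properties object carries
  // extra settings such as user/password into DriverManager.getConnection.
  def makeConnector(driver: String, url: String, properties: Properties): () => Connection = {
    () => {
      try {
        if (driver != null) Class.forName(driver)
      } catch {
        case e: ClassNotFoundException =>
          println(s"Couldn't find class $driver: $e") // the real code uses logWarning
      }
      DriverManager.getConnection(url, properties)
    }
  }

  def main(args: Array[String]): Unit = {
    val props = new Properties() // illustrative values only
    props.setProperty("user", "testUser")
    props.setProperty("password", "testPass")
    val connect = makeConnector("org.h2.Driver", "jdbc:h2:mem:testdb0", props)
    val conn = connect() // the connection is opened only at this point
    conn.close()
  }
}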
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -17,16 +17,17 @@

package org.apache.spark.sql.jdbc

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.types.StructType
import java.sql.DriverManager
import java.util.Properties

import scala.collection.mutable.ArrayBuffer
import java.sql.DriverManager

import org.apache.spark.Partition
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType

/**
* Data corresponding to one partition of a JDBCRDD.
@@ -115,18 +116,21 @@ private[sql] class DefaultSource extends RelationProvider {
numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
JDBCRelation(url, table, parts)(sqlContext)
val properties = new Properties() // Additional properties that we will pass to getConnection
Contributor
We should probably document this behavior somewhere -- I'm not sure if there's a standard place for data sources. If nowhere else, at least add it to the doc of createRelation().

Contributor
We should add this to the programming guide under JDBC.

parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
JDBCRelation(url, table, parts, properties)(sqlContext)
}
}

private[sql] case class JDBCRelation(
url: String,
table: String,
parts: Array[Partition])(@transient val sqlContext: SQLContext)
parts: Array[Partition],
properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
extends BaseRelation
with PrunedFilteredScan {

override val schema: StructType = JDBCRDD.resolveTable(url, table)
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
@@ -135,6 +139,7 @@ private[sql] case class JDBCRelation(
schema,
driver,
url,
properties,
table,
requiredColumns,
filters,
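On the reviewer comments above: after this change every key/value pair in the OPTIONS clause is copied into the java.util.Properties passed to the JDBC driver, so driver-specific settings such as user, password, or rowId travel through unchanged (drivers generally ignore keys they do not recognize). A hedged usage sketch in the style of the test suite below; it assumes the suite's H2 setup and TestSQLContext imports, and the temporary table name employees is made up.

sql(
  """
    |CREATE TEMPORARY TABLE employees
    |USING org.apache.spark.sql.jdbc
    |OPTIONS (url 'jdbc:h2:mem:testdb0', dbtable 'TEST.PEOPLE',
    |  user 'testUser', password 'testPass')
  """.stripMargin.replaceAll("\n", " "))

// url and dbtable are copied into the Properties as well; only user and password
// matter to the H2 driver in this example.
val rows = sql("SELECT * FROM employees").collect()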
55 changes: 39 additions & 16 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -19,22 +19,31 @@ package org.apache.spark.sql.jdbc

import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar}
import java.util.{Calendar, GregorianCalendar, Properties}

import org.apache.spark.sql.test._
import org.h2.jdbc.JdbcSQLException
import org.scalatest.{FunSuite, BeforeAndAfter}
import TestSQLContext._
import TestSQLContext.implicits._

class JDBCSuite extends FunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null

val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)

before {
Class.forName("org.h2.Driver")
conn = DriverManager.getConnection(url)
// Extra properties that will be specified for our database. We need these to test
// usage of parameters from OPTIONS clause in queries.
val properties = new Properties()
properties.setProperty("user", "testUser")
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")

conn = DriverManager.getConnection(url, properties)
conn.prepareStatement("create schema test").executeUpdate()
conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()
@@ -46,15 +55,15 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE foobar
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE')
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

sql(
s"""
|CREATE TEMPORARY TABLE parts
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE',
|partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
| partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, "
@@ -68,12 +77,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE inttypes
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.INTTYPES')
|OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), "
+ "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate()
var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
stmt.setBytes(1, testBytes)
stmt.setString(2, "Sensitive")
stmt.setString(3, "Insensitive")
@@ -85,7 +94,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE strtypes
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.STRTYPES')
|OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)"
@@ -97,7 +106,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE timetypes
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.TIMETYPES')
|OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))


@@ -112,7 +121,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
s"""
|CREATE TEMPORARY TABLE flttypes
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.FLTTYPES')
|OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

// Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
@@ -174,16 +183,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}

test("Basic API") {
assert(TestSQLContext.jdbc(url, "TEST.PEOPLE").collect.size == 3)
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3)
}

test("Partitioning via JDBCPartitioningInfo API") {
assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", "THEID", 0, 4, 3).collect.size == 3)
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3)
.collect.size == 3)
}

test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", parts).collect.size == 3)
assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3)
}

test("H2 integral types") {
@@ -216,7 +226,6 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(rows(0).getString(5).equals("I am a clob!"))
}


test("H2 time types") {
val rows = sql("SELECT * FROM timetypes").collect()
val cal = new GregorianCalendar(java.util.Locale.ROOT)
@@ -246,17 +255,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
.equals(new BigDecimal("123456789012345.54321543215432100000")))
}


test("SQL query as table name") {
sql(
s"""
|CREATE TEMPORARY TABLE hack
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)')
|OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)',
| user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
val rows = sql("SELECT * FROM hack").collect()
assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==.
// For some reason, H2 computes this square incorrectly...
assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
}

test("Pass extra properties via OPTIONS") {
// We set rowId to false during setup, which means that _ROWID_ column should be absent from
// all tables. If rowId is true (default), the query below doesn't throw an exception.
intercept[JdbcSQLException] {
sql(
s"""
|CREATE TEMPORARY TABLE abc
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)',
| user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
}
}
}
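A standalone JDBC sketch of what the last test relies on, outside Spark: with rowId=false forwarded to H2, the _ROWID_ pseudo-column is absent, so selecting it fails. This assumes the suite's in-memory database and test.people table are still alive in the same JVM; it illustrates the behavior the intercept above asserts and is not part of the PR.

import java.sql.{DriverManager, SQLException}
import java.util.Properties

object RowIdCheck {
  def main(args: Array[String]): Unit = {
    Class.forName("org.h2.Driver")
    val props = new Properties()
    props.setProperty("user", "testUser")
    props.setProperty("password", "testPass")
    props.setProperty("rowId", "false") // same property the suite forwards via OPTIONS
    val conn = DriverManager.getConnection("jdbc:h2:mem:testdb0", props)
    try {
      // Expected to fail: _ROWID_ should not exist when rowId is false.
      conn.prepareStatement("SELECT _ROWID_ FROM test.people").executeQuery()
      println("unexpected: query succeeded")
    } catch {
      case e: SQLException => println(s"query failed as expected: ${e.getMessage}")
    } finally {
      conn.close()
    }
  }
}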