diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index b5eca76480eb8..72f6970e1800a 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1308,6 +1308,13 @@ the following case-insensitive options:
+
+ sessionInitStatement |
+
+ This option executes a custom SQL statement (or a PL/SQL block) after each database session is opened to the remote DB and before starting to read data. Use it to implement session initialization code. Example: option("sessionInitStatement", """BEGIN execute immediate 'alter session set "_serial_direct_read"=true'; END;""")
+ |
+
+
truncate |
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 591096d5efd22..ff804da2a79cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -135,6 +135,8 @@ class JDBCOptions(
case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ
case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE
}
+ // An option to execute custom SQL before fetching data from the remote DB
+ val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT)
}
object JDBCOptions {
@@ -158,4 +160,5 @@ object JDBCOptions {
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
+ val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 24e13697c0c9f..3274be91d4817 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -273,6 +273,21 @@ private[jdbc] class JDBCRDD(
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
+ // This executes a generic SQL statement (or PL/SQL block) before reading
+ // the table/query via JDBC. Use this feature to initialize the database
+ // session environment, e.g. for optimizations and/or troubleshooting.
+ options.sessionInitStatement match {
+ case Some(sql) =>
+ val statement = conn.prepareStatement(sql)
+ logInfo(s"Executing sessionInitStatement: $sql")
+ try {
+ statement.execute()
+ } finally {
+ statement.close()
+ }
+ case None =>
+ }
+
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d1daf860fdfff..b21adbdbf1362 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -1007,4 +1007,35 @@ class JDBCSuite extends SparkFunSuite
assert(sql("select * from people_view").count() == 3)
}
}
+
+ test("SPARK-21519: option sessionInitStatement, run SQL to initialize the database session.") {
+ val initSQL1 = "SET @MYTESTVAR 21519"
+ val df1 = spark.read.format("jdbc")
+ .option("url", urlWithUserAndPass)
+ .option("dbtable", "(SELECT NVL(@MYTESTVAR, -1))")
+ .option("sessionInitStatement", initSQL1)
+ .load()
+ assert(df1.collect() === Array(Row(21519)))
+
+ val initSQL2 = "SET SCHEMA DUMMY"
+ val df2 = spark.read.format("jdbc")
+ .option("url", urlWithUserAndPass)
+ .option("dbtable", "TEST.PEOPLE")
+ .option("sessionInitStatement", initSQL2)
+ .load()
+ val e = intercept[SparkException] {df2.collect()}.getMessage
+ assert(e.contains("""Schema "DUMMY" not found"""))
+
+ sql(
+ s"""
+ |CREATE OR REPLACE TEMPORARY VIEW test_sessionInitStatement
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$urlWithUserAndPass',
+ |dbtable '(SELECT NVL(@MYTESTVAR1, -1), NVL(@MYTESTVAR2, -1))',
+ |sessionInitStatement 'SET @MYTESTVAR1 21519; SET @MYTESTVAR2 1234')
+ """.stripMargin)
+
+ val df3 = sql("SELECT * FROM test_sessionInitStatement")
+ assert(df3.collect() === Array(Row(21519, 1234)))
+ }
}
|