diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index d184f3cb71b1a..5d1feaed81a9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -147,14 +147,7 @@ class JDBCOptions(
       """.stripMargin
     )

-  val fetchSize = {
-    val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
-    require(size >= 0,
-      s"Invalid value `${size.toString}` for parameter " +
-        s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " +
-        "the JDBC driver ignores the value and does the estimates.")
-    size
-  }
+  val fetchSize = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt

   // ------------------------------------------------------------
   // Optional parameters only for writing
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e25ce53941ff6..37879e98fcbb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.jdbc

 import java.sql.{Connection, PreparedStatement, ResultSet, SQLException}

+import scala.collection.JavaConverters._
 import scala.util.control.NonFatal

 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -184,6 +185,8 @@ private[jdbc] class JDBCRDD(
     options: JDBCOptions)
   extends RDD[InternalRow](sc, Nil) {

+  JdbcDialects.get(url).validateProperties(options.asProperties.asScala.toMap)
+
   /**
    * Retrieve the list of partitions corresponding to this RDD.
    */
@@ -271,7 +274,6 @@ private[jdbc] class JDBCRDD(
     val part = thePart.asInstanceOf[JDBCPartition]
     conn = getConnection()
     val dialect = JdbcDialects.get(url)
-    import scala.collection.JavaConverters._
     dialect.beforeFetch(conn, options.asProperties.asScala.toMap)

     // This executes a generic SQL statement (or PL/SQL block) before reading
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index a0c6d20f36451..c8b6492b7f6c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp}
 import org.apache.commons.lang3.StringUtils

 import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._

 /**
@@ -150,6 +151,20 @@ abstract class JdbcDialect extends Serializable {
   def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   }

+  /**
+   * Do some extra properties validation work in addition to the validation
+   * in [[org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions]].
+   * @param properties The connection properties. This is passed through from the relation.
+   */
+  def validateProperties(properties: Map[String, String]): Unit = {
+    val fetchSize = properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt
+    require(fetchSize >= 0,
+      s"Invalid value `${fetchSize.toString}` for parameter " +
+        s"`${JDBCOptions.JDBC_BATCH_FETCH_SIZE}` for dialect ${this.getClass.getSimpleName}. " +
+        "The minimum value is 0. When the value is 0, " +
+        "the JDBC driver ignores the value and does the estimates.")
+  }
+
   /**
    * Escape special characters in SQL string literals.
    * @param value The string to be escaped.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index b2cff7877d8b5..ece8a17628b2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc

 import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder}

 private case object MySQLDialect extends JdbcDialect {
@@ -46,4 +47,15 @@ private case object MySQLDialect extends JdbcDialect {
   }

   override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+
+  override def validateProperties(properties: Map[String, String]): Unit = {
+    val fetchSize = properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt
+    require(fetchSize >= 0 || fetchSize == Integer.MIN_VALUE,
+      s"Invalid value `${fetchSize.toString}` for parameter " +
+        s"`${JDBCOptions.JDBC_BATCH_FETCH_SIZE}` for MySQL. " +
+        "The value should be >= 0 or equal to Integer.MIN_VALUE; " +
+        "when the value is 0, the JDBC driver ignores the value and does the estimates; " +
+        "when the value is Integer.MIN_VALUE, the data is fetched in a streaming manner, " +
+        "namely, one row at a time.")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 3c8ce0a3fc3e4..5b2aaa735950f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -42,6 +42,7 @@ import org.apache.spark.util.Utils
 class JDBCSuite extends QueryTest
   with BeforeAndAfter with PrivateMethodTester with SharedSparkSession {
+  import scala.collection.JavaConverters._
   import testImplicits._

   val url = "jdbc:h2:mem:testdb0"
@@ -459,6 +460,30 @@ class JDBCSuite extends QueryTest
     assert(e.contains("Invalid value `-1` for parameter `fetchsize`"))
   }

+  test("[SPARK-21287] Dialect validate properties") {
+    val mysqlDialect = JdbcDialects.get("jdbc:mysql:xxx")
+    val h2Dialect = JdbcDialects.get("jdbc:h2:xxx")
+    val properties = new Properties()
+    properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "-1")
+    val e1 = intercept[IllegalArgumentException] {
+      mysqlDialect.validateProperties(properties.asScala.toMap)
+    }.getMessage
+    val e2 = intercept[IllegalArgumentException] {
+      h2Dialect.validateProperties(properties.asScala.toMap)
+    }.getMessage
+    properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "1")
+    mysqlDialect.validateProperties(properties.asScala.toMap)
+    h2Dialect.validateProperties(properties.asScala.toMap)
+    properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, Integer.MIN_VALUE.toString)
+    mysqlDialect.validateProperties(properties.asScala.toMap)
+    val e3 = intercept[IllegalArgumentException] {
+      h2Dialect.validateProperties(properties.asScala.toMap)
+    }.getMessage
+    assert(e1.contains("Invalid value `-1` for parameter `fetchsize`"))
+    assert(e2.contains("Invalid value `-1` for parameter `fetchsize`"))
+    assert(e3.contains(s"Invalid value `${Integer.MIN_VALUE.toString}` for parameter `fetchsize`"))
+  }
+
   test("Missing partition columns") {
     withView("tempPeople") {
       val e = intercept[IllegalArgumentException] {