diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 591096d5efd2..fb021c1e245a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc
 import java.sql.{Connection, DriverManager}
 import java.util.{Locale, Properties}

+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

 /**
@@ -33,6 +34,14 @@ class JDBCOptions(

   def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))

+  @DeveloperApi
+  def this(url: String) = {
+    this(CaseInsensitiveMap(Map(
+      JDBCOptions.JDBC_URL -> url,
+      JDBCOptions.JDBC_DRIVER_CLASS -> "org.h2.Driver",
+      JDBCOptions.JDBC_TABLE_NAME -> "")))
+  }
+
   def this(url: String, table: String, parameters: Map[String, String]) = {
     this(CaseInsensitiveMap(parameters ++ Map(
       JDBCOptions.JDBC_URL -> url,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 2bdc43254133..4b1c98859264 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -52,9 +52,8 @@ object JDBCRDD extends Logging {
    * @throws SQLException if the table contains an unsupported type.
    */
   def resolveTable(options: JDBCOptions): StructType = {
-    val url = options.url
     val table = options.table
-    val dialect = JdbcDialects.get(url)
+    val dialect = JdbcDialects.get(options)
     val conn: Connection = JdbcUtils.createConnectionFactory(options)()
     try {
       val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
@@ -167,8 +166,7 @@ object JDBCRDD extends Logging {
       filters: Array[Filter],
       parts: Array[Partition],
       options: JDBCOptions): RDD[InternalRow] = {
-    val url = options.url
-    val dialect = JdbcDialects.get(url)
+    val dialect = JdbcDialects.get(options)
     val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
     new JDBCRDD(
       sc,
@@ -177,7 +175,7 @@
       quotedColumns,
       filters,
       parts,
-      url,
+      options.url,
       options)
   }
 }
@@ -217,7 +215,7 @@ private[jdbc] class JDBCRDD(
    */
   private val filterWhereClause: String = filters
-    .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
+    .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(options)))
     .map(p => s"($p)").mkString(" AND ")

   /**
@@ -284,7 +282,7 @@ private[jdbc] class JDBCRDD(
     val inputMetrics = context.taskMetrics().inputMetrics
     val part = thePart.asInstanceOf[JDBCPartition]
     conn = getConnection()
-    val dialect = JdbcDialects.get(url)
+    val dialect = JdbcDialects.get(options)
     import scala.collection.JavaConverters._
     dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap)
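Reviewer note, not part of the patch: the new single-argument JDBCOptions constructor above is marked @DeveloperApi and fills in defaults for everything except the url, pinning the driver class to org.h2.Driver and leaving dbtable empty. A minimal sketch of what a caller gets from it, assuming only what the constructor shown above sets:

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

// Only the url is supplied; driver and dbtable fall back to the defaults
// hard-coded in the new constructor ("org.h2.Driver" and "").
val opts = new JDBCOptions("jdbc:h2:mem:testdb0")
assert(opts.url == "jdbc:h2:mem:testdb0")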
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index a06f1ce3287e..35d99b17fc41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -114,7 +114,7 @@ private[sql] case class JDBCRelation(

   // Check if JDBCRDD.compileFilter can accept input filters
   override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
-    filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
+    filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions)).isEmpty)
   }

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 74dcfb06f5c2..bf7720bb1c9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -61,7 +61,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
     if (tableExists) {
       mode match {
         case SaveMode.Overwrite =>
-          if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
+          if (options.isTruncate && isCascadingTruncateTable(options) == Some(false)) {
             // In this case, we should truncate table and then load.
             truncateTable(conn, options.table)
             val tableSchema = JdbcUtils.getSchemaOption(conn, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index ca61c2efe2dd..96598b3f19af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -66,7 +66,7 @@ object JdbcUtils extends Logging {
    * Returns true if the table already exists in the JDBC database.
    */
   def tableExists(conn: Connection, options: JDBCOptions): Boolean = {
-    val dialect = JdbcDialects.get(options.url)
+    val dialect = JdbcDialects.get(options)

     // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
     // SQL database systems using JDBC meta data calls, considering "table" could also include
@@ -105,8 +105,8 @@
     }
   }

-  def isCascadingTruncateTable(url: String): Option[Boolean] = {
-    JdbcDialects.get(url).isCascadingTruncateTable()
+  def isCascadingTruncateTable(options: JDBCOptions): Option[Boolean] = {
+    JdbcDialects.get(options).isCascadingTruncateTable()
   }

   /**
@@ -247,7 +247,7 @@ object JdbcUtils extends Logging {
    * Returns the schema if the table already exists in the JDBC database.
    */
   def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = {
-    val dialect = JdbcDialects.get(options.url)
+    val dialect = JdbcDialects.get(options)

     try {
       val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
@@ -702,10 +702,10 @@ object JdbcUtils extends Logging {
    */
   def schemaString(
       df: DataFrame,
-      url: String,
+      options: JDBCOptions,
       createTableColumnTypes: Option[String] = None): String = {
     val sb = new StringBuilder()
-    val dialect = JdbcDialects.get(url)
+    val dialect = JdbcDialects.get(options)
     val userSpecifiedColTypesMap = createTableColumnTypes
       .map(parseUserSpecifiedCreateTableColumnTypes(df, _))
       .getOrElse(Map.empty[String, String])
@@ -772,9 +772,8 @@ object JdbcUtils extends Logging {
       tableSchema: Option[StructType],
       isCaseSensitive: Boolean,
       options: JDBCOptions): Unit = {
-    val url = options.url
     val table = options.table
-    val dialect = JdbcDialects.get(url)
+    val dialect = JdbcDialects.get(options)
     val rddSchema = df.schema
     val getConnection: () => Connection = createConnectionFactory(options)
     val batchSize = options.batchSize
@@ -801,7 +800,7 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       options: JDBCOptions): Unit = {
     val strSchema = schemaString(
-      df, options.url, options.createTableColumnTypes)
+      df, options, options.createTableColumnTypes)
     val table = options.table
     val createTableOptions = options.createTableOptions
     // Create the table if the table does not exist.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index 467d8d62d1b7..ec5cec106cfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.jdbc

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types.{DataType, MetadataBuilder}

 /**
@@ -30,8 +31,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect

   require(dialects.nonEmpty)

-  override def canHandle(url : String): Boolean =
-    dialects.map(_.canHandle(url)).reduce(_ && _)
+  override def canHandle(options: JDBCOptions): Boolean =
+    dialects.map(_.canHandle(options)).reduce(_ && _)

   override def getCatalystType(
     sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 190463df0d92..235c7acd9e45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -17,11 +17,12 @@

 package org.apache.spark.sql.jdbc

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types.{BooleanType, DataType, StringType}

 private object DB2Dialect extends JdbcDialect {

-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")
+  override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:db2")

   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB))
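Reviewer note, not part of the patch: JdbcUtils.schemaString now takes the whole JDBCOptions instead of a bare url, so the dialect lookup can see every option. A minimal sketch of the updated call site, mirroring the test changes later in this patch; the local SparkSession and the MySQL url are placeholders, and schemaString opens no connection:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}

object SchemaStringSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("schemaString-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq(("a", 1)).toDF("name", "id")
    // The second argument is now a JDBCOptions; the MySQL url only selects the dialect.
    val ddl = JdbcUtils.schemaString(df, new JDBCOptions("jdbc:mysql://localhost:3306/temp"))
    println(ddl) // the column list used by createTable, e.g. `name` TEXT, `id` INTEGER NOT NULL
    spark.stop()
  }
}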
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 84f68e779c38..e222b931dba2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.jdbc

 import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._


 private object DerbyDialect extends JdbcDialect {

-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")
+  override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:derby")

   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index a86a86d40890..eb4b14480b76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.sql.Connection

 import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._

 /**
@@ -58,11 +59,11 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
 abstract class JdbcDialect extends Serializable {
   /**
    * Check if this dialect instance can handle a certain jdbc url.
-   * @param url the jdbc url.
+   * @param options the jdbc options.
    * @return True if the dialect can be applied on the given jdbc url.
    * @throws NullPointerException if the url is null.
    */
-  def canHandle(url : String): Boolean
+  def canHandle(options: JDBCOptions): Boolean

   /**
    * Get the custom datatype mapping for the given jdbc meta information.
@@ -179,8 +180,8 @@ object JdbcDialects {
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
    */
-  def get(url: String): JdbcDialect = {
-    val matchingDialects = dialects.filter(_.canHandle(url))
+  def get(options: JDBCOptions): JdbcDialect = {
+    val matchingDialects = dialects.filter(_.canHandle(options))
     matchingDialects.length match {
       case 0 => NoopDialect
       case 1 => matchingDialects.head
@@ -193,5 +194,5 @@ object JdbcDialects {
  * NOOP dialect object, always returning the neutral element.
  */
 private object NoopDialect extends JdbcDialect {
-  override def canHandle(url : String): Boolean = true
+  override def canHandle(options: JDBCOptions): Boolean = true
 }
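Reviewer note, not part of the patch: because canHandle now receives the full JDBCOptions, a third-party dialect can key off any connection option rather than only the url prefix. A minimal sketch of such a dialect under the new signature; the object name and the "useMyDialect" option are invented for illustration, and registration works exactly as for the test dialects later in this patch:

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType}

object MyH2Dialect extends JdbcDialect {
  // Handle H2 urls, but only when the caller opted in via a custom option.
  override def canHandle(options: JDBCOptions): Boolean = {
    options.url.startsWith("jdbc:h2") &&
      options.asProperties.getProperty("useMyDialect", "false").toBoolean
  }

  // Map every column to StringType, as the test H2 dialect in this patch does.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    Some(StringType)
}

// Registration is unchanged:
// JdbcDialects.registerDialect(MyH2Dialect)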
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index da787b4859a7..8777156e98f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -17,12 +17,13 @@

 package org.apache.spark.sql.jdbc

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._


 private object MsSqlServerDialect extends JdbcDialect {

-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver")
+  override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:sqlserver")

   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index b2cff7877d8b..6b7d46250676 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.jdbc

 import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder}

 private case object MySQLDialect extends JdbcDialect {

-  override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
+  override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:mysql")

   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index f541996b651e..3945d8777748 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -19,17 +19,25 @@ package org.apache.spark.sql.jdbc

 import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._


 private case object OracleDialect extends JdbcDialect {

-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
+  private var isAutoConvertNumber2Boolean: Boolean = true
+
+  override def canHandle(options: JDBCOptions): Boolean = {
+    isAutoConvertNumber2Boolean =
+      options.asProperties.getProperty("autoConvertNumber2Boolean", true.toString).toBoolean
+    options.url.startsWith("jdbc:oracle")
+  }

   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.NUMERIC) {
-      val scale = if (null != md) md.build().getLong("scale") else 0L
+      val scale =
+        if (null != md && md.build().contains("scale")) md.build().getLong("scale") else 0L
       size match {
         // Handle NUMBER fields that have no precision/scale in special way
         // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
@@ -43,7 +51,8 @@ private case object OracleDialect extends JdbcDialect {
         // Not sure if there is a more robust way to identify the field as a float (or other
         // numeric types that do not specify a scale.
         case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
-        case 1 => Option(BooleanType)
+        case 1 if isAutoConvertNumber2Boolean => Option(BooleanType)
+        case 1 if !isAutoConvertNumber2Boolean => Option(IntegerType)
         case 3 | 5 | 10 => Option(IntegerType)
         case 19 if scale == 0L => Option(LongType)
         case 19 if scale == 4L => Option(FloatType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 4f61a328f47c..6baf6f9a85ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._

 private object PostgresDialect extends JdbcDialect {

-  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
+  override def canHandle(options: JDBCOptions): Boolean = options.url.startsWith("jdbc:postgresql")

   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 5749b791fca2..e93b23d4a236 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.jdbc

 import java.sql.Types

+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._


 private case object TeradataDialect extends JdbcDialect {

-  override def canHandle(url: String): Boolean = { url.startsWith("jdbc:teradata") }
+  override def canHandle(options: JDBCOptions): Boolean = {
+    options.url.startsWith("jdbc:teradata")
+  }

   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
     case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR))
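Reviewer note, not part of the patch: the only behavioural change in a built-in dialect is in OracleDialect, where NUMBER(1) still maps to BooleanType by default but the new autoConvertNumber2Boolean option switches it to IntegerType. A hedged sketch of how a reader would opt out, assuming the option reaches JDBCOptions unchanged (as it does in this patch's tests); the connection details are placeholders and the snippet needs a reachable Oracle instance plus driver to actually run:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("oracle-number1-sketch").getOrCreate()

val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:oracle:thin:@//dbhost:1521/service") // placeholder
  .option("dbtable", "SOME_SCHEMA.SOME_TABLE")              // placeholder
  .option("autoConvertNumber2Boolean", "false")             // NUMBER(1) columns come back as IntegerType
  .load()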
JDBCOptions("jdbc:postgresql://127.0.0.1/db")) == PostgresDialect) + assert(JdbcDialects.get(new JDBCOptions("jdbc:db2://127.0.0.1/db")) == DB2Dialect) + assert(JdbcDialects.get(new JDBCOptions("jdbc:sqlserver://127.0.0.1/db")) == MsSqlServerDialect) + assert(JdbcDialects.get(new JDBCOptions("jdbc:derby:db")) == DerbyDialect) + assert(JdbcDialects.get(new JDBCOptions("test.invalid")) == NoopDialect) } test("quote column names by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - val Derby = JdbcDialects.get("jdbc:derby:db") + val MySQL = JdbcDialects.get(new JDBCOptions("jdbc:mysql://127.0.0.1/db")) + val Postgres = JdbcDialects.get(new JDBCOptions("jdbc:postgresql://127.0.0.1/db")) + val Derby = JdbcDialects.get(new JDBCOptions("jdbc:derby:db")) val columns = Seq("abc", "key") val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) @@ -659,7 +659,8 @@ class JDBCSuite extends SparkFunSuite test("compile filters") { val compileFilter = PrivateMethod[Option[String]]('compileFilter) def doCompileFilter(f: Filter): String = - JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("") + JDBCRDD invokePrivate compileFilter(f, + JdbcDialects.get(new JDBCOptions("jdbc:"))) getOrElse("") assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""") assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""") assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) @@ -689,12 +690,12 @@ class JDBCSuite extends SparkFunSuite test("Dialect unregister") { JdbcDialects.registerDialect(testH2Dialect) JdbcDialects.unregisterDialect(testH2Dialect) - assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect) + assert(JdbcDialects.get(new JDBCOptions(urlWithUserAndPass)) == NoopDialect) } test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def canHandle(options: JDBCOptions) : Boolean = options.url.startsWith("jdbc:h2:") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { @@ -703,20 +704,20 @@ class JDBCSuite extends SparkFunSuite None } }, testH2Dialect)) - assert(agg.canHandle("jdbc:h2:xxx")) - assert(!agg.canHandle("jdbc:h2")) + assert(agg.canHandle(new JDBCOptions("jdbc:h2:xxx"))) + assert(!agg.canHandle(new JDBCOptions("jdbc:h2"))) assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } test("DB2Dialect type mapping") { - val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val db2Dialect = JdbcDialects.get(new JDBCOptions("jdbc:db2://127.0.0.1/db")) assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } test("PostgresDialect type mapping") { - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val Postgres = JdbcDialects.get(new JDBCOptions("jdbc:postgresql://127.0.0.1/db")) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) assert(Postgres.getJDBCType(FloatType).map(_.databaseTypeDefinition).get == "FLOAT4") @@ -728,27 +729,39 @@ class JDBCSuite extends SparkFunSuite } 
test("DerbyDialect jdbc type mapping") { - val derbyDialect = JdbcDialects.get("jdbc:derby:db") + val derbyDialect = JdbcDialects.get(new JDBCOptions("jdbc:derby:db")) assert(derbyDialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(derbyDialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN") } test("OracleDialect jdbc type mapping") { - val oracleDialect = JdbcDialects.get("jdbc:oracle") + val jdbcOptions = new JDBCOptions("jdbc:oracle") + val oracleDialect = JdbcDialects.get(jdbcOptions) val metadata = new MetadataBuilder().putString("name", "test_column").putLong("scale", -127) - assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "float", 1, metadata) == + assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "float", 1, metadata) === Some(DecimalType(DecimalType.MAX_PRECISION, 10))) - assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) == + assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) === Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + + val metadataN1 = new MetadataBuilder().putString("name", "test_column") + // number(1) to BooleanType default + val oracleDialectN11 = JdbcDialects.get(jdbcOptions) + assert(oracleDialectN11.getCatalystType(java.sql.Types.NUMERIC, "NUMBER", 1, metadataN1) === + Some(BooleanType)) + // number(1) to IntegerType + jdbcOptions.asProperties.put("autoConvertNumber2Boolean", "false") + val oracleDialect12 = JdbcDialects.get(jdbcOptions) + assert(oracleDialect12.getCatalystType(java.sql.Types.NUMERIC, "NUMBER", 1, metadataN1) === + Some(IntegerType)) } test("table exists query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") - val h2 = JdbcDialects.get(url) - val derby = JdbcDialects.get("jdbc:derby:db") + val MySQL = JdbcDialects.get(new JDBCOptions("jdbc:mysql://127.0.0.1/db")) + val Postgres = JdbcDialects.get(new JDBCOptions("jdbc:postgresql://127.0.0.1/db")) + val db2 = JdbcDialects.get(new JDBCOptions("jdbc:db2://127.0.0.1/db")) + val h2 = JdbcDialects.get(new JDBCOptions(url)) + val derby = JdbcDialects.get(new JDBCOptions("jdbc:derby:db")) val table = "weblogs" val defaultQuery = s"SELECT * FROM $table WHERE 1=0" val limitQuery = s"SELECT 1 FROM $table LIMIT 1" @@ -820,7 +833,7 @@ class JDBCSuite extends SparkFunSuite } test("SPARK 12941: The data type mapping for StringType to Oracle") { - val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val oracleDialect = JdbcDialects.get(new JDBCOptions("jdbc:oracle://127.0.0.1/db")) assert(oracleDialect.getJDBCType(StringType). 
map(_.databaseTypeDefinition).get == "VARCHAR2(255)") } @@ -832,7 +845,7 @@ class JDBCSuite extends SparkFunSuite map(_.databaseTypeDefinition).get } - val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val oracleDialect = JdbcDialects.get(new JDBCOptions("jdbc:oracle://127.0.0.1/db")) assert(getJdbcType(oracleDialect, BooleanType) == "NUMBER(1)") assert(getJdbcType(oracleDialect, IntegerType) == "NUMBER(10)") assert(getJdbcType(oracleDialect, LongType) == "NUMBER(19)") @@ -874,7 +887,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") - val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") + val schema = JdbcUtils.schemaString(df, new JDBCOptions("jdbc:mysql://localhost:3306/temp")) assert(schema.contains("`order` TEXT")) } @@ -923,13 +936,13 @@ class JDBCSuite extends SparkFunSuite } test("SPARK-15648: teradataDialect StringType data mapping") { - val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + val teradataDialect = JdbcDialects.get(new JDBCOptions("jdbc:teradata://127.0.0.1/db")) assert(teradataDialect.getJDBCType(StringType). map(_.databaseTypeDefinition).get == "VARCHAR(255)") } test("SPARK-15648: teradataDialect BooleanType data mapping") { - val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + val teradataDialect = JdbcDialects.get(new JDBCOptions("jdbc:teradata://127.0.0.1/db")) assert(teradataDialect.getJDBCType(BooleanType). map(_.databaseTypeDefinition).get == "CHAR(1)") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index bf1fd160704f..c697a5be885f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -34,9 +34,9 @@ import org.apache.spark.util.Utils class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { - val url = "jdbc:h2:mem:testdb2" + val jdbcOptions = new JDBCOptions("jdbc:h2:mem:testdb2") var conn: java.sql.Connection = null - val url1 = "jdbc:h2:mem:testdb3" + val jdbcOptions1 = new JDBCOptions("jdbc:h2:mem:testdb3") var conn1: java.sql.Connection = null val properties = new Properties() properties.setProperty("user", "testUser") @@ -44,16 +44,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { properties.setProperty("rowId", "false") val testH2Dialect = new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def canHandle(options: JDBCOptions) : Boolean = options.url.startsWith("jdbc:h2") override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } before { Utils.classForName("org.h2.Driver") - conn = DriverManager.getConnection(url) + conn = DriverManager.getConnection(jdbcOptions.url) conn.prepareStatement("create schema test").executeUpdate() - conn1 = DriverManager.getConnection(url1, properties) + conn1 = DriverManager.getConnection(jdbcOptions1.url, properties) conn1.prepareStatement("create schema test").executeUpdate() conn1.prepareStatement("drop table if exists test.people").executeUpdate() conn1.prepareStatement( @@ -69,14 +69,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { s""" |CREATE OR REPLACE TEMPORARY VIEW PEOPLE |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 
'testUser', password 'testPass') + |OPTIONS (url '${jdbcOptions1.url}', + |dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) sql( s""" |CREATE OR REPLACE TEMPORARY VIEW PEOPLE1 |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') + |OPTIONS (url '${jdbcOptions1.url}', + |dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) } @@ -104,10 +106,13 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("Basic CREATE") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties()) - assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count()) - assert( - 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).collect()(0).length) + df.write.jdbc(jdbcOptions.url, "TEST.BASICCREATETEST", new Properties()) + assert(2 === + spark.read.jdbc(jdbcOptions.url, "TEST.BASICCREATETEST", new Properties()).count()) + assert(2 === + spark.read.jdbc(jdbcOptions.url, + "TEST.BASICCREATETEST", + new Properties()).collect()(0).length) } test("Basic CREATE with illegal batchsize") { @@ -117,7 +122,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val properties = new Properties() properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString) val e = intercept[IllegalArgumentException] { - df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) + df.write.mode(SaveMode.Overwrite).jdbc(jdbcOptions.url, "TEST.BASICCREATETEST", properties) }.getMessage assert(e.contains(s"Invalid value `$size` for parameter `batchsize`")) } @@ -129,8 +134,9 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { (1 to 3).foreach { size => val properties = new Properties() properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString) - df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) - assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count()) + df.write.mode(SaveMode.Overwrite).jdbc(jdbcOptions.url, "TEST.BASICCREATETEST", properties) + assert(2 === + spark.read.jdbc(jdbcOptions.url, "TEST.BASICCREATETEST", new Properties()).count()) } } @@ -138,55 +144,57 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) - df.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) - assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df.write.mode(SaveMode.Ignore).jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).collect()(0).length) - df2.write.mode(SaveMode.Ignore).jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) - assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df2.write.mode(SaveMode.Ignore).jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).count()) + assert(3 === 
spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE with overwrite") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) - df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) - assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df.write.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).collect()(0).length) - df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) - assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df2.write.mode(SaveMode.Overwrite).jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties) + assert(1 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).count()) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) - df.write.jdbc(url, "TEST.APPENDTEST", new Properties()) - df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) - assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count()) - assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length) + df.write.jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()) + df2.write.mode(SaveMode.Append).jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()) + assert(3 === spark.read.jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()).count()) + assert(2 === + spark.read.jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()).collect()(0).length) } test("SPARK-18123 Append with column names with different cases") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4) - df.write.jdbc(url, "TEST.APPENDTEST", new Properties()) + df.write.jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { val m = intercept[AnalysisException] { - df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) + df2.write.mode(SaveMode.Append).jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()) }.getMessage assert(m.contains("Column \"NAME\" not found")) } withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) - assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count()) - assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length) + df2.write.mode(SaveMode.Append).jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()) + assert(3 === spark.read.jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()).count()) + assert(2 === + spark.read.jdbc(jdbcOptions.url, "TEST.APPENDTEST", new Properties()).collect()(0).length) } } @@ -196,18 +204,19 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), 
schema2) val df3 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) + df.write.jdbc(jdbcOptions1.url, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).option("truncate", true) - .jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) - assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + .jdbc(jdbcOptions1.url, "TEST.TRUNCATETEST", properties) + assert(1 === spark.read.jdbc(jdbcOptions1.url, "TEST.TRUNCATETEST", properties).count()) + assert(2 === + spark.read.jdbc(jdbcOptions1.url, "TEST.TRUNCATETEST", properties).collect()(0).length) val m = intercept[AnalysisException] { df3.write.mode(SaveMode.Overwrite).option("truncate", true) - .jdbc(url1, "TEST.TRUNCATETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.TRUNCATETEST", properties) }.getMessage assert(m.contains("Column \"seq\" not found")) - assert(0 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) + assert(0 === spark.read.jdbc(jdbcOptions1.url, "TEST.TRUNCATETEST", properties).count()) JdbcDialects.unregisterDialect(testH2Dialect) } @@ -217,7 +226,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val m = intercept[org.h2.jdbc.JdbcSQLException] { df.write.option("createTableOptions", "ENGINE tableEngineName") - .jdbc(url1, "TEST.CREATETBLOPTS", properties) + .jdbc(jdbcOptions1.url, "TEST.CREATETBLOPTS", properties) }.getMessage assert(m.contains("Class \"TABLEENGINENAME\" not found")) JdbcDialects.unregisterDialect(testH2Dialect) @@ -227,36 +236,37 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) + df.write.jdbc(jdbcOptions.url, "TEST.INCOMPATIBLETEST", new Properties()) val m = intercept[AnalysisException] { - df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) + df2.write.mode(SaveMode.Append) + .jdbc(jdbcOptions.url, "TEST.INCOMPATIBLETEST", new Properties()) }.getMessage assert(m.contains("Column \"seq\" not found")) } test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) - assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.PEOPLE1", properties).count()) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) - assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.PEOPLE1", properties).count()) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.PEOPLE1", properties).collect()(0).length) } test("save works for format(\"jdbc\") if url and dbtable are set") { val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.format("jdbc") - .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST")) + .options(Map("url" -> jdbcOptions.url, "dbtable" -> 
"TEST.SAVETEST")) .save() - assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count) - assert( - 2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(jdbcOptions.url, "TEST.SAVETEST", new Properties).count) + assert(2 === + sqlContext.read.jdbc(jdbcOptions.url, "TEST.SAVETEST", new Properties).collect()(0).length) } test("save API with SaveMode.Overwrite") { @@ -264,17 +274,17 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.format("jdbc") - .option("url", url1) + .option("url", jdbcOptions1.url) .option("dbtable", "TEST.SAVETEST") .options(properties.asScala) .save() df2.write.mode(SaveMode.Overwrite).format("jdbc") - .option("url", url1) + .option("url", jdbcOptions1.url) .option("dbtable", "TEST.SAVETEST") .options(properties.asScala) .save() - assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count()) - assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(jdbcOptions1.url, "TEST.SAVETEST", properties).count()) + assert(2 === spark.read.jdbc(jdbcOptions1.url, "TEST.SAVETEST", properties).collect()(0).length) } test("save errors if url is not specified") { @@ -294,7 +304,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val e = intercept[RuntimeException] { df.write.format("jdbc") - .option("url", url1) + .option("url", jdbcOptions1.url) .options(properties.asScala) .save() }.getMessage @@ -307,7 +317,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val e = intercept[org.h2.jdbc.JdbcSQLException] { df.write.format("jdbc") .option("dbtable", "TEST.SAVETEST") - .option("url", url1) + .option("url", jdbcOptions1.url) .save() }.getMessage assert(e.contains("Wrong user name or password")) @@ -319,7 +329,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val e = intercept[java.lang.IllegalArgumentException] { df.write.format("jdbc") .option("dbtable", "TEST.SAVETEST") - .option("url", url1) + .option("url", jdbcOptions1.url) .option("partitionColumn", "foo") .save() }.getMessage @@ -330,7 +340,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.format("jdbc") - .option("Url", url1) + .option("Url", jdbcOptions1.url) .option("dbtable", "TEST.SAVETEST") .options(properties.asScala) .save() @@ -341,7 +351,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val e = intercept[IllegalArgumentException] { df.write.format("jdbc") .option("dbtable", "TEST.SAVETEST") - .option("url", url1) + .option("url", jdbcOptions1.url) .option("user", "testUser") .option("password", "testPass") .option(s"${JDBCOptions.JDBC_NUM_PARTITIONS}", "0") @@ -357,7 +367,8 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { s""" |CREATE TEMPORARY VIEW people_view |USING org.apache.spark.sql.jdbc - |OPTIONS (uRl '$url1', DbTaBlE 'TEST.PEOPLE1', User 'testUser', PassWord 'testPass') + |OPTIONS (uRl '${jdbcOptions1.url}', + |DbTaBlE 'TEST.PEOPLE1', User 'testUser', PassWord 'testPass') """.stripMargin.replaceAll("\n", " ")) sql("INSERT OVERWRITE TABLE PEOPLE_VIEW SELECT * FROM PEOPLE") assert(sql("select * from people_view").count() == 2) @@ -376,7 +387,8 @@ class 
JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val expectedSchemaStr = colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ") - assert(JdbcUtils.schemaString(df, url1, Option(createTableColTypes)) == expectedSchemaStr) + assert(JdbcUtils.schemaString(df, jdbcOptions1, Option(createTableColTypes)) == + expectedSchemaStr) } testCreateTableColDataTypes(Seq("boolean")) @@ -396,7 +408,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { df.write .mode(SaveMode.Overwrite) .option("createTableColumnTypes", createTableColTypes) - .jdbc(url1, "TEST.DBCOLTYPETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.DBCOLTYPETEST", properties) // verify the data types of the created table by reading the database catalog of H2 val query = @@ -404,7 +416,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { |(SELECT column_name, type_name, character_maximum_length | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST') """.stripMargin - val rows = spark.read.jdbc(url1, query, properties).collect() + val rows = spark.read.jdbc(jdbcOptions1.url, query, properties).collect() rows.foreach { row => val typeName = row.getString(1) @@ -454,7 +466,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val msg = intercept[ParseException] { df.write.mode(SaveMode.Overwrite) .option("createTableColumnTypes", "name CLOB(2000)") - .jdbc(url1, "TEST.USERDBTYPETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains("DataType clob(2000) is not supported.")) } @@ -464,7 +476,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val msg = intercept[ParseException] { df.write.mode(SaveMode.Overwrite) .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column - .jdbc(url1, "TEST.USERDBTYPETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains("no viable alternative at input")) } @@ -475,7 +487,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val msg = intercept[AnalysisException] { df.write.mode(SaveMode.Overwrite) .option("createTableColumnTypes", "name CHAR(20), id int, NaMe VARCHAR(100)") - .jdbc(url1, "TEST.USERDBTYPETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains( "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe")) @@ -490,7 +502,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val msg = intercept[AnalysisException] { df.write.mode(SaveMode.Overwrite) .option("createTableColumnTypes", "firstName CHAR(20), id int") - .jdbc(url1, "TEST.USERDBTYPETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains("createTableColumnTypes option column firstName not found in " + "schema struct")) @@ -500,7 +512,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val msg = intercept[AnalysisException] { df.write.mode(SaveMode.Overwrite) .option("createTableColumnTypes", "id int, Name VARCHAR(100)") - .jdbc(url1, "TEST.USERDBTYPETEST", properties) + .jdbc(jdbcOptions1.url, "TEST.USERDBTYPETEST", properties) }.getMessage() assert(msg.contains("createTableColumnTypes option column Name not found in " + "schema struct"))
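Reviewer note, not part of the patch: any external caller that passes a bare url to JdbcDialects.get has to wrap it now, which is what the test updates above do throughout. A minimal before/after sketch using only APIs touched by this patch:

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialects

// Before this patch:
//   val dialect = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
// After this patch, wrap the url (or pass the JDBCOptions you already have):
val dialect = JdbcDialects.get(new JDBCOptions("jdbc:postgresql://127.0.0.1/db"))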