diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index d160ad82888a..ab574df4557a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.jdbc import java.sql.Types +import java.util.Locale import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + override def canHandle(url: String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") override def getCatalystType( sqlType: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index d13c29ed46bd..d528d5a9fef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.jdbc import java.sql.Types +import java.util.Locale import org.apache.spark.sql.types._ private object DerbyDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + override def canHandle(url: String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 805f73dee141..2511067abc3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.jdbc +import java.util.Locale + import org.apache.spark.sql.types._ private object MsSqlServerDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def canHandle(url: String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index b2cff7877d8b..24b31b14d942 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.jdbc import java.sql.Types +import java.util.Locale import org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder} private case object MySQLDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def canHandle(url : String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index f4a6d0a4d2e4..4c0623729e00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Date, Timestamp, Types} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -30,7 +30,8 @@ private case object OracleDialect extends JdbcDialect { private[jdbc] val BINARY_DOUBLE = 101 private[jdbc] val TIMESTAMPTZ = -101 - override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + override def canHandle(url: String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") private def supportTimeZoneTypes: Boolean = { val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 2645e4c9d528..c8d8a3392128 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, Types} +import java.util.Locale import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types._ @@ -25,7 +26,8 @@ import org.apache.spark.sql.types._ private object PostgresDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def canHandle(url: String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 6c17bd7ed9ec..552d7a484f3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.util.Locale import org.apache.spark.sql.types._ private case object TeradataDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = { url.startsWith("jdbc:teradata") } + override def canHandle(url: String): Boolean = + url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 472bdda01d13..43f6381c1979 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -51,7 +51,7 @@ class JDBCSuite extends QueryTest val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) @@ -1662,4 +1662,21 @@ class JDBCSuite extends QueryTest "Invalid value `test` for parameter `isolationLevel`. This can be " + "`NONE`, `READ_UNCOMMITTED`, `READ_COMMITTED`, `REPEATABLE_READ` or `SERIALIZABLE`.")) } + + test("SPARK-28552: Case-insensitive database URLs in JdbcDialect") { + assert(JdbcDialects.get("jdbc:mysql://localhost/db") === MySQLDialect) + assert(JdbcDialects.get("jdbc:MySQL://localhost/db") === MySQLDialect) + assert(JdbcDialects.get("jdbc:postgresql://localhost/db") === PostgresDialect) + assert(JdbcDialects.get("jdbc:postGresql://localhost/db") === PostgresDialect) + assert(JdbcDialects.get("jdbc:db2://localhost/db") === DB2Dialect) + assert(JdbcDialects.get("jdbc:DB2://localhost/db") === DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://localhost/db") === MsSqlServerDialect) + assert(JdbcDialects.get("jdbc:sqlServer://localhost/db") === MsSqlServerDialect) + assert(JdbcDialects.get("jdbc:derby://localhost/db") === DerbyDialect) + assert(JdbcDialects.get("jdbc:derBy://localhost/db") === DerbyDialect) + assert(JdbcDialects.get("jdbc:oracle://localhost/db") === OracleDialect) + assert(JdbcDialects.get("jdbc:Oracle://localhost/db") === OracleDialect) + assert(JdbcDialects.get("jdbc:teradata://localhost/db") === TeradataDialect) + assert(JdbcDialects.get("jdbc:Teradata://localhost/db") === TeradataDialect) + } }