From 708a320747e25eedecd8e75d89f26bbc09784c2a Mon Sep 17 00:00:00 2001
From: Daniel van der Ende
Date: Wed, 6 Dec 2017 15:06:25 +0100
Subject: [PATCH 1/2] [SPARK-22729][SQL] Add getTruncateQuery to JdbcDialect

In order to enable the truncate option for PostgreSQL databases in Spark
JDBC, the query used to truncate a PostgreSQL table has to change. By
default, PostgreSQL automatically truncates any descendant tables when a
TRUNCATE query is executed against a parent table. As this may result in
unwanted side effects, the truncate query for PostgreSQL must be specified
separately, so that only the named table is truncated. This change
replaces isCascadingTruncateTable with a getTruncateQuery method that each
dialect must implement.
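
To illustrate (a sketch against this change, not part of the patch; the
table name "weblogs" is a placeholder):

    import org.apache.spark.sql.jdbc.JdbcDialects

    val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
    val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")

    // Postgres truncates only the named table, leaving descendants alone:
    postgres.getTruncateQuery("weblogs")  // "TRUNCATE ONLY weblogs"
    // Other dialects keep a plain single-table statement:
    mysql.getTruncateQuery("weblogs")     // "TRUNCATE weblogs"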
---
 .../jdbc/JdbcRelationProvider.scala           |  4 +-
 .../datasources/jdbc/JdbcUtils.scala          | 10 ++--
 .../spark/sql/jdbc/AggregatedDialect.scala    | 12 ++---
 .../apache/spark/sql/jdbc/DB2Dialect.scala    |  4 +-
 .../apache/spark/sql/jdbc/DerbyDialect.scala  |  4 ++
 .../apache/spark/sql/jdbc/JdbcDialects.scala  | 18 ++++----
 .../spark/sql/jdbc/MsSqlServerDialect.scala   |  4 +-
 .../apache/spark/sql/jdbc/MySQLDialect.scala  |  4 +-
 .../apache/spark/sql/jdbc/OracleDialect.scala |  4 +-
 .../spark/sql/jdbc/PostgresDialect.scala      | 14 ++++--
 .../spark/sql/jdbc/TeradataDialect.scala      |  4 ++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 46 ++++++++-----------
 .../spark/sql/jdbc/JDBCWriteSuite.scala       |  1 -
 13 files changed, 69 insertions(+), 60 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 37e7bb0a59bb..b90c16cccad8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -66,9 +66,9 @@ class JdbcRelationProvider extends CreatableRelationProvider
       if (tableExists) {
         mode match {
           case SaveMode.Overwrite =>
-            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
+            if (options.isTruncate) {
               // In this case, we should truncate table and then load.
-              truncateTable(conn, options.table)
+              truncateTable(conn, options)
               val tableSchema = JdbcUtils.getSchemaOption(conn, options)
               saveTable(df, tableSchema, isCaseSensitive, options)
             } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 75c94fc48649..7796bb9965d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -96,20 +96,18 @@ object JdbcUtils extends Logging {
   }
 
   /**
-   * Truncates a table from the JDBC database.
+   * Truncates a table from the JDBC database without side effects.
    */
-  def truncateTable(conn: Connection, table: String): Unit = {
+  def truncateTable(conn: Connection, options: JDBCOptions): Unit = {
+    val dialect = JdbcDialects.get(options.url)
     val statement = conn.createStatement
     try {
-      statement.executeUpdate(s"TRUNCATE TABLE $table")
+      statement.executeUpdate(dialect.getTruncateQuery(options.table))
     } finally {
       statement.close()
     }
   }
 
-  def isCascadingTruncateTable(url: String): Option[Boolean] = {
-    JdbcDialects.get(url).isCascadingTruncateTable()
-  }
 
   /**
    * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index f3bfea5f6bfc..6a3b67451bc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder}
 /**
  * AggregatedDialect can unify multiple dialects into one virtual Dialect.
  * Dialects are tried in order, and the first dialect that does not return a
- * neutral element will will.
+ * neutral element will win.
  *
  * @param dialects List of dialects.
  */
@@ -54,13 +54,7 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
     dialects.head.getSchemaQuery(table)
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = {
-    // If any dialect claims cascading truncate, this dialect is also cascading truncate.
-    // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown.
-    dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) match {
-      case Some(true) => Some(true)
-      case _ if dialects.exists(_.isCascadingTruncateTable().isEmpty) => None
-      case _ => Some(false)
-    }
+  override def getTruncateQuery(table: String): String = {
+    dialects.head.getTruncateQuery(table)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index d160ad82888a..5aa1ad6f007e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -48,5 +48,7 @@ private object DB2Dialect extends JdbcDialect {
     case _ => None
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE $table"
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 84f68e779c38..a38eec5f4109 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -41,4 +41,8 @@ private object DerbyDialect extends JdbcDialect {
       Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL))
     case _ => None
   }
+
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE $table"
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 7c38ed68c041..f771c7885ebb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -116,6 +116,16 @@ abstract class JdbcDialect extends Serializable {
     s"SELECT * FROM $table WHERE 1=0"
   }
 
+  /**
+   * The SQL query that should be used to truncate a table. Dialects can override this method to
+   * return a query that is suitable for a particular database. For PostgreSQL, for instance,
+   * a different query is used to prevent "TRUNCATE" affecting other tables.
+   * @param table The name of the table.
+   * @return The SQL query to use for truncating a table
+   */
+  @Since("2.3.0")
+  def getTruncateQuery(table: String): String
+
   /**
    * Override connection specific properties to run before a select is made. This is in place to
    * allow dialects that need special treatment to optimize behavior.
@@ -147,14 +157,6 @@ abstract class JdbcDialect extends Serializable {
     case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
     case _ => value
   }
-
-  /**
-   * Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
-   * Some[true] : TRUNCATE TABLE causes cascading.
-   * Some[false] : TRUNCATE TABLE does not cause cascading.
-   * None: The behavior of TRUNCATE TABLE is unknown (default).
-   */
-  def isCascadingTruncateTable(): Option[Boolean] = None
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index da787b4859a7..7825b2056a41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -41,5 +41,7 @@ private object MsSqlServerDialect extends JdbcDialect {
     case _ => None
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE $table"
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index b2cff7877d8b..b5d7c9e0da83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -45,5 +45,7 @@ private case object MySQLDialect extends JdbcDialect {
     s"SELECT 1 FROM $table LIMIT 1"
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE $table"
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index e3f106c41c7f..a4898aa4574d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -83,5 +83,7 @@ private case object OracleDialect extends JdbcDialect {
     case _ => value
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE $table"
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 4f61a328f47c..69fb2649868f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect {
     s"SELECT 1 FROM $table LIMIT 1"
   }
 
+  /**
+   * The SQL query used to truncate a table. For Postgres, the default behaviour is to
+   * also truncate any descendant tables. As this is a (possibly unwanted) side-effect,
+   * the Postgres dialect adds 'ONLY' to truncate only the table in question
+   * @param table The name of the table.
+   * @return The SQL query to use for truncating a table
+   */
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE ONLY $table"
+  }
+
   override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
     super.beforeFetch(connection, properties)
 
@@ -97,8 +108,5 @@ private object PostgresDialect extends JdbcDialect {
     if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
       connection.setAutoCommit(false)
     }
-
   }
-
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 5749b791fca2..cc546ef017bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -31,4 +31,8 @@ private case object TeradataDialect extends JdbcDialect {
     case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
     case _ => None
   }
+
+  override def getTruncateQuery(table: String): String = {
+    s"TRUNCATE $table"
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 61571bccdcb5..e7179bbd31f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -752,43 +752,19 @@ class JDBCSuite extends SparkFunSuite
       override def getSchemaQuery(table: String): String = {
         s"My $table Schema"
       }
-      override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
+      override def getTruncateQuery(table: String): String = {
+        s"TRUNCATE $table"
+      }
     }, testH2Dialect))
     assert(agg.canHandle("jdbc:h2:xxx"))
     assert(!agg.canHandle("jdbc:h2"))
     assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
     assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
-    assert(agg.isCascadingTruncateTable() === Some(true))
     assert(agg.quoteIdentifier ("Dummy") === "My Dummy quoteIdentifier")
     assert(agg.getTableExistsQuery ("Dummy") === "My Dummy Table")
     assert(agg.getSchemaQuery ("Dummy") === "My Dummy Schema")
   }
 
-  test("Aggregated dialects: isCascadingTruncateTable") {
-    def genDialect(cascadingTruncateTable: Option[Boolean]): JdbcDialect = new JdbcDialect {
-      override def canHandle(url: String): Boolean = true
-      override def getCatalystType(
-        sqlType: Int,
-        typeName: String,
-        size: Int,
-        md: MetadataBuilder): Option[DataType] = None
-      override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable
-    }
-
-    def testDialects(cascadings: List[Option[Boolean]], expected: Option[Boolean]): Unit = {
-      val dialects = cascadings.map(genDialect(_))
-      val agg = new AggregatedDialect(dialects)
-      assert(agg.isCascadingTruncateTable() === expected)
-    }
-
-    testDialects(List(Some(true), Some(false), None), Some(true))
-    testDialects(List(Some(true), Some(true), None), Some(true))
-    testDialects(List(Some(false), Some(false), None), None)
-    testDialects(List(Some(true), Some(true)), Some(true))
-    testDialects(List(Some(false), Some(false)), Some(false))
-    testDialects(List(None, None), None)
-  }
-
   test("DB2Dialect type mapping") {
     val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
     assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
@@ -854,6 +830,22 @@ class JDBCSuite extends SparkFunSuite
     assert(derby.getTableExistsQuery(table) == defaultQuery)
   }
 
+  test("truncate table query by jdbc dialect") {
+    val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+    val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+    val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
+    val h2 = JdbcDialects.get(url)
+    val derby = JdbcDialects.get("jdbc:derby:db")
+    val table = "weblogs"
+    val defaultQuery = s"TRUNCATE $table"
+    val postgresQuery = s"TRUNCATE ONLY $table"
+    assert(MySQL.getTableExistsQuery(table) == defaultQuery)
+    assert(Postgres.getTableExistsQuery(table) == postgresQuery)
+    assert(db2.getTableExistsQuery(table) == defaultQuery)
+    assert(h2.getTableExistsQuery(table) == defaultQuery)
+    assert(derby.getTableExistsQuery(table) == defaultQuery)
+  }
+
   test("Test DataFrame.where for Date and Timestamp") {
     // Regression test for bug SPARK-11788
     val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 1985b1dc8287..3c69df20f502 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -46,7 +46,6 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
 
   val testH2Dialect = new JdbcDialect {
     override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
-    override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
   }
 
   before {

From 0b990e1451491c6319194475bc429bf06ab16c21 Mon Sep 17 00:00:00 2001
From: Daniel van der Ende
Date: Sat, 9 Dec 2017 17:50:05 +0100
Subject: [PATCH 2/2] Reinstate isCascadingTruncateTable function
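
For context, the user-facing path this guard protects (an illustrative
sketch, not part of the patch; the URL, table name, and empty connection
properties are placeholders):

    import java.util.Properties

    import org.apache.spark.sql.{SaveMode, SparkSession}

    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    val df = spark.range(10).toDF("id")

    // Truncate-and-reload on overwrite happens only when the dialect reports
    // a non-cascading TRUNCATE, i.e. isCascadingTruncateTable() == Some(false);
    // otherwise Spark falls back to dropping and recreating the table.
    df.write
      .mode(SaveMode.Overwrite)
      .option("truncate", "true")
      .jdbc("jdbc:postgresql://127.0.0.1/db", "weblogs", new Properties())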
---
 .../jdbc/JdbcRelationProvider.scala           |  2 +-
 .../datasources/jdbc/JdbcUtils.scala          |  3 ++
 .../spark/sql/jdbc/AggregatedDialect.scala    | 10 +++++
 .../apache/spark/sql/jdbc/DB2Dialect.scala    |  4 +-
 .../apache/spark/sql/jdbc/DerbyDialect.scala  |  4 --
 .../apache/spark/sql/jdbc/JdbcDialects.scala  | 12 ++++-
 .../spark/sql/jdbc/MsSqlServerDialect.scala   |  4 +-
 .../apache/spark/sql/jdbc/MySQLDialect.scala  |  4 +-
 .../apache/spark/sql/jdbc/OracleDialect.scala |  4 +-
 .../spark/sql/jdbc/PostgresDialect.scala      |  4 +-
 .../spark/sql/jdbc/TeradataDialect.scala      |  4 --
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 44 ++++++++++++++-----
 .../spark/sql/jdbc/JDBCWriteSuite.scala       |  1 +
 13 files changed, 67 insertions(+), 33 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index b90c16cccad8..cc506e51bd0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -66,7 +66,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
       if (tableExists) {
         mode match {
           case SaveMode.Overwrite =>
-            if (options.isTruncate) {
+            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
               // In this case, we should truncate table and then load.
               truncateTable(conn, options)
               val tableSchema = JdbcUtils.getSchemaOption(conn, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 7796bb9965d3..2193b4cdc383 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -108,6 +108,9 @@ object JdbcUtils extends Logging {
     }
   }
 
+  def isCascadingTruncateTable(url: String): Option[Boolean] = {
+    JdbcDialects.get(url).isCascadingTruncateTable()
+  }
 
   /**
    * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index 6a3b67451bc4..8b92c8b4f56b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -54,6 +54,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
     dialects.head.getSchemaQuery(table)
   }
 
+  override def isCascadingTruncateTable(): Option[Boolean] = {
+    // If any dialect claims cascading truncate, this dialect is also cascading truncate.
+    // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown.
+    dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) match {
+      case Some(true) => Some(true)
+      case _ if dialects.exists(_.isCascadingTruncateTable().isEmpty) => None
+      case _ => Some(false)
+    }
+  }
+
   override def getTruncateQuery(table: String): String = {
     dialects.head.getTruncateQuery(table)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 5aa1ad6f007e..d160ad82888a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -48,7 +48,5 @@ private object DB2Dialect extends JdbcDialect {
     case _ => None
   }
 
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE $table"
-  }
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index a38eec5f4109..84f68e779c38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -41,8 +41,4 @@ private object DerbyDialect extends JdbcDialect {
       Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL))
     case _ => None
   }
-
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE $table"
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index f771c7885ebb..83d87a11810c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -124,7 +124,9 @@ abstract class JdbcDialect extends Serializable {
    * @return The SQL query to use for truncating a table
    */
   @Since("2.3.0")
-  def getTruncateQuery(table: String): String
+  def getTruncateQuery(table: String): String = {
+    s"TRUNCATE TABLE $table"
+  }
 
   /**
    * Override connection specific properties to run before a select is made. This is in place to
@@ -157,6 +159,14 @@ abstract class JdbcDialect extends Serializable {
     case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
     case _ => value
   }
+
+  /**
+   * Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
+   * Some[true] : TRUNCATE TABLE causes cascading.
+   * Some[false] : TRUNCATE TABLE does not cause cascading.
+   * None: The behavior of TRUNCATE TABLE is unknown (default).
+   */
+  def isCascadingTruncateTable(): Option[Boolean] = None
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 7825b2056a41..da787b4859a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -41,7 +41,5 @@ private object MsSqlServerDialect extends JdbcDialect {
     case _ => None
   }
 
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE $table"
-  }
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index b5d7c9e0da83..b2cff7877d8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -45,7 +45,5 @@ private case object MySQLDialect extends JdbcDialect {
     s"SELECT 1 FROM $table LIMIT 1"
   }
 
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE $table"
-  }
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index a4898aa4574d..e3f106c41c7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -83,7 +83,5 @@ private case object OracleDialect extends JdbcDialect {
     case _ => value
   }
 
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE $table"
-  }
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 69fb2649868f..13a2035f4d0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -93,7 +93,7 @@ private object PostgresDialect extends JdbcDialect {
    * @return The SQL query to use for truncating a table
    */
   override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE ONLY $table"
+    s"TRUNCATE TABLE ONLY $table"
   }
 
   override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
@@ -109,4 +109,6 @@ private object PostgresDialect extends JdbcDialect {
       connection.setAutoCommit(false)
     }
   }
+
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index cc546ef017bb..5749b791fca2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -31,8 +31,4 @@ private case object TeradataDialect extends JdbcDialect {
     case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
     case _ => None
   }
-
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE $table"
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index e7179bbd31f8..43b9307e5aab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -752,19 +752,43 @@ class JDBCSuite extends SparkFunSuite
       override def getSchemaQuery(table: String): String = {
         s"My $table Schema"
       }
-      override def getTruncateQuery(table: String): String = {
-        s"TRUNCATE $table"
-      }
+      override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
     }, testH2Dialect))
     assert(agg.canHandle("jdbc:h2:xxx"))
     assert(!agg.canHandle("jdbc:h2"))
     assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
     assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
+    assert(agg.isCascadingTruncateTable() === Some(true))
     assert(agg.quoteIdentifier ("Dummy") === "My Dummy quoteIdentifier")
     assert(agg.getTableExistsQuery ("Dummy") === "My Dummy Table")
     assert(agg.getSchemaQuery ("Dummy") === "My Dummy Schema")
   }
 
+  test("Aggregated dialects: isCascadingTruncateTable") {
+    def genDialect(cascadingTruncateTable: Option[Boolean]): JdbcDialect = new JdbcDialect {
+      override def canHandle(url: String): Boolean = true
+      override def getCatalystType(
+        sqlType: Int,
+        typeName: String,
+        size: Int,
+        md: MetadataBuilder): Option[DataType] = None
+      override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable
+    }
+
+    def testDialects(cascadings: List[Option[Boolean]], expected: Option[Boolean]): Unit = {
+      val dialects = cascadings.map(genDialect(_))
+      val agg = new AggregatedDialect(dialects)
+      assert(agg.isCascadingTruncateTable() === expected)
+    }
+
+    testDialects(List(Some(true), Some(false), None), Some(true))
+    testDialects(List(Some(true), Some(true), None), Some(true))
+    testDialects(List(Some(false), Some(false), None), None)
+    testDialects(List(Some(true), Some(true)), Some(true))
+    testDialects(List(Some(false), Some(false)), Some(false))
+    testDialects(List(None, None), None)
+  }
+
   test("DB2Dialect type mapping") {
     val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
     assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
@@ -837,13 +861,13 @@ class JDBCSuite extends SparkFunSuite
     val h2 = JdbcDialects.get(url)
     val derby = JdbcDialects.get("jdbc:derby:db")
     val table = "weblogs"
-    val defaultQuery = s"TRUNCATE $table"
-    val postgresQuery = s"TRUNCATE ONLY $table"
-    assert(MySQL.getTableExistsQuery(table) == defaultQuery)
-    assert(Postgres.getTableExistsQuery(table) == postgresQuery)
-    assert(db2.getTableExistsQuery(table) == defaultQuery)
-    assert(h2.getTableExistsQuery(table) == defaultQuery)
-    assert(derby.getTableExistsQuery(table) == defaultQuery)
+    val defaultQuery = s"TRUNCATE TABLE $table"
+    val postgresQuery = s"TRUNCATE TABLE ONLY $table"
+    assert(MySQL.getTruncateQuery(table) == defaultQuery)
+    assert(Postgres.getTruncateQuery(table) == postgresQuery)
+    assert(db2.getTruncateQuery(table) == defaultQuery)
+    assert(h2.getTruncateQuery(table) == defaultQuery)
+    assert(derby.getTruncateQuery(table) == defaultQuery)
   }
 
   test("Test DataFrame.where for Date and Timestamp") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 3c69df20f502..1985b1dc8287 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -46,6 +46,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
 
   val testH2Dialect = new JdbcDialect {
     override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
+    override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
   }
 
   before {
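
Note for dialect authors (a sketch, not part of the patch; the "jdbc:mydb"
URL prefix and the IMMEDIATE clause are invented for illustration):
third-party dialects can override the new hook and register themselves.

    import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

    object MyDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

      // Invented requirement: this database wants an explicit IMMEDIATE clause.
      override def getTruncateQuery(table: String): String =
        s"TRUNCATE TABLE $table IMMEDIATE"

      // TRUNCATE does not cascade on this (hypothetical) database, so the
      // truncate-on-overwrite path stays enabled.
      override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
    }

    JdbcDialects.registerDialect(MyDialect)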