diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 37e7bb0a59bb..cc506e51bd0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.Overwrite => if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) { // In this case, we should truncate table and then load. - truncateTable(conn, options.table) + truncateTable(conn, options) val tableSchema = JdbcUtils.getSchemaOption(conn, options) saveTable(df, tableSchema, isCaseSensitive, options) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 75c94fc48649..2193b4cdc383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -96,12 +96,13 @@ object JdbcUtils extends Logging { } /** - * Truncates a table from the JDBC database. + * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, table: String): Unit = { + def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { - statement.executeUpdate(s"TRUNCATE TABLE $table") + statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index f3bfea5f6bfc..8b92c8b4f56b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder} /** * AggregatedDialect can unify multiple dialects into one virtual Dialect. * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. + * neutral element will win. * * @param dialects List of dialects. */ @@ -63,4 +63,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect case _ => Some(false) } } + + override def getTruncateQuery(table: String): String = { + dialects.head.getTruncateQuery(table) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 7c38ed68c041..83d87a11810c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -116,6 +116,18 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The name of the table. + * @return The SQL query to use for truncating a table + */ + @Since("2.3.0") + def getTruncateQuery(table: String): String = { + s"TRUNCATE TABLE $table" + } + /** * Override connection specific properties to run before a select is made. This is in place to * allow dialects that need special treatment to optimize behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 4f61a328f47c..13a2035f4d0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + /** + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The name of the table. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery(table: String): String = { + s"TRUNCATE TABLE ONLY $table" + } + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { super.beforeFetch(connection, properties) @@ -97,8 +108,7 @@ private object PostgresDialect extends JdbcDialect { if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { connection.setAutoCommit(false) } - } - override def isCascadingTruncateTable(): Option[Boolean] = Some(true) + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 61571bccdcb5..43b9307e5aab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -854,6 +854,22 @@ class JDBCSuite extends SparkFunSuite assert(derby.getTableExistsQuery(table) == defaultQuery) } + test("truncate table query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table" + assert(MySQL.getTruncateQuery(table) == defaultQuery) + assert(Postgres.getTruncateQuery(table) == postgresQuery) + assert(db2.getTruncateQuery(table) == defaultQuery) + assert(h2.getTruncateQuery(table) == defaultQuery) + assert(derby.getTruncateQuery(table) == defaultQuery) + } + test("Test DataFrame.where for Date and Timestamp") { // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");