diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index e6a8dfac0adc..753b64b983d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -387,6 +387,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
* your external database systems.
*
+ * You can set the following JDBC-specific option(s) when storing data via JDBC:
+ *
+ * `truncate` (default `false`): use `TRUNCATE TABLE` instead of `DROP TABLE`.
+ *
+ * In case of failures, users should turn off the `truncate` option to fall back to `DROP TABLE`.
+ * Also, because the behavior of `TRUNCATE TABLE` differs among DBMSs, it's not always safe to use
+ * this. MySQLDialect, DB2Dialect, MsSqlServerDialect, DerbyDialect, and OracleDialect support
+ * this, while PostgresDialect and the default JdbcDialect do not. For unknown and unsupported
+ * JdbcDialects, the user option `truncate` is ignored.
+ *
* @param url JDBC database url of the form `jdbc:subprotocol:subname`
* @param table Name of the table in the external database.
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
@@ -423,8 +432,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
if (mode == SaveMode.Overwrite && tableExists) {
- JdbcUtils.dropTable(conn, table)
- tableExists = false
+ if (extraOptions.getOrElse("truncate", "false").toBoolean &&
+ JdbcUtils.isCascadingTruncateTable(url) == Some(false)) {
+ JdbcUtils.truncateTable(conn, table)
+ } else {
+ JdbcUtils.dropTable(conn, table)
+ tableExists = false
+ }
}
// Create the table if the table didn't exist.
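
For reference, a minimal usage sketch of the new option (the DataFrame `df`, the H2 URL, and the table name here are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SaveMode

// Overwrite an existing JDBC table in place. With truncate=true and a
// dialect whose isCascadingTruncateTable() returns Some(false), the writer
// issues TRUNCATE TABLE instead of DROP TABLE + CREATE TABLE, so the
// existing table definition (indexes, grants, etc.) is preserved.
val props = new java.util.Properties()
df.write
  .mode(SaveMode.Overwrite)
  .option("truncate", "true") // default "false": drop and recreate the table
  .jdbc("jdbc:h2:mem:testdb", "TEST.TRUNCATETEST", props)
```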
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index ce71a7d1e6a2..cb474cbd0ae7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -98,6 +98,22 @@ object JdbcUtils extends Logging {
}
}
+ /**
+ * Truncates a table from the JDBC database.
+ */
+ def truncateTable(conn: Connection, table: String): Unit = {
+ val statement = conn.createStatement()
+ try {
+ statement.executeUpdate(s"TRUNCATE TABLE $table")
+ } finally {
+ statement.close()
+ }
+ }
+
+ def isCascadingTruncateTable(url: String): Option[Boolean] = {
+ JdbcDialects.get(url).isCascadingTruncateTable()
+ }
+
/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
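
Once the user has requested truncation, the writer consults the dialect before choosing an overwrite strategy. Condensed from the DataFrameWriter change above, the decision reduces to:

```scala
// Truncate only when the dialect positively guarantees that TRUNCATE TABLE
// does not cascade. Some(true) (cascading) and None (behavior unknown) both
// fall back to the old DROP TABLE path.
JdbcUtils.isCascadingTruncateTable(url) match {
  case Some(false) => JdbcUtils.truncateTable(conn, table)
  case _           => JdbcUtils.dropTable(conn, table)
}
```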
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index f12b6ca9d6ad..190463df0d92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -28,4 +28,6 @@ private object DB2Dialect extends JdbcDialect {
case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
case _ => None
}
+
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 948106fd062a..78107809a1cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -108,6 +108,13 @@ abstract class JdbcDialect extends Serializable {
def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
}
+ /**
+ * Return Some[true] iff `TRUNCATE TABLE` causes cascading deletes by default.
+ * Some[true] : TRUNCATE TABLE causes cascading.
+ * Some[false] : TRUNCATE TABLE does not cause cascading.
+ * None: The behavior of TRUNCATE TABLE is unknown (default).
+ */
+ def isCascadingTruncateTable(): Option[Boolean] = None
}
/**
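
Third-party dialects can opt in through the same hook; a hypothetical sketch (`MyDialect` and the `jdbc:mydb` URL prefix are invented for illustration):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// A custom dialect that declares its TRUNCATE TABLE non-cascading, which
// makes tables behind "jdbc:mydb" URLs eligible for the truncate path.
object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}

JdbcDialects.registerDialect(MyDialect)
```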
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 3eb722b070d5..70122f259914 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -38,4 +38,6 @@ private object MsSqlServerDialect extends JdbcDialect {
case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP))
case _ => None
}
+
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index e1717049f383..b2cff7877d8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -44,4 +44,6 @@ private case object MySQLDialect extends JdbcDialect {
override def getTableExistsQuery(table: String): String = {
s"SELECT 1 FROM $table LIMIT 1"
}
+
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index b795e8b42df0..ce8731efd166 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -53,4 +53,6 @@ private case object OracleDialect extends JdbcDialect {
case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR))
case _ => None
}
+
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 6baf1b6f16cd..fb959d881e9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -94,4 +94,6 @@ private object PostgresDialect extends JdbcDialect {
}
}
+
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 2c6449fa6870..d99b3cf975f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -40,6 +40,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
+ val testH2Dialect = new JdbcDialect {
+ override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
+ override def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+ Some(StringType)
+ override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+ }
+
before {
Utils.classForName("org.h2.Driver")
conn = DriverManager.getConnection(url)
@@ -145,14 +153,25 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}
- test("CREATE then INSERT to truncate") {
+ test("Truncate") {
+ JdbcDialects.registerDialect(testH2Dialect)
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+ val df3 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
- df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
+ df2.write.mode(SaveMode.Overwrite).option("truncate", true)
+ .jdbc(url1, "TEST.TRUNCATETEST", properties)
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
+
+ val m = intercept[SparkException] {
+ df3.write.mode(SaveMode.Overwrite).option("truncate", true)
+ .jdbc(url1, "TEST.TRUNCATETEST", properties)
+ }.getMessage
+ assert(m.contains("Column \"seq\" not found"))
+ assert(0 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
+ JdbcDialects.unregisterDialect(testH2Dialect)
}
test("Incompatible INSERT to append") {