Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
case SaveMode.Overwrite =>
if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, options.table)
truncateTable(conn, options)
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
saveTable(df, tableSchema, isCaseSensitive, options)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ object JdbcUtils extends Logging {
}

/**
* Truncates a table from the JDBC database.
* Truncates a table from the JDBC database without side effects.
*/
def truncateTable(conn: Connection, table: String): Unit = {
def truncateTable(conn: Connection, options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(options.url)
val statement = conn.createStatement
try {
statement.executeUpdate(s"TRUNCATE TABLE $table")
statement.executeUpdate(dialect.getTruncateQuery(options.table))
} finally {
statement.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.types.{DataType, MetadataBuilder}
/**
* AggregatedDialect can unify multiple dialects into one virtual Dialect.
* Dialects are tried in order, and the first dialect that does not return a
* neutral element will will.
* neutral element will win.
*
* @param dialects List of dialects.
*/
Expand Down Expand Up @@ -63,4 +63,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
case _ => Some(false)
}
}

override def getTruncateQuery(table: String): String = {
dialects.head.getTruncateQuery(table)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ abstract class JdbcDialect extends Serializable {
s"SELECT * FROM $table WHERE 1=0"
}

/**
* The SQL query that should be used to truncate a table. Dialects can override this method to
* return a query that is suitable for a particular database. For PostgreSQL, for instance,
* a different query is used to prevent "TRUNCATE" affecting other tables.
* @param table The name of the table.
* @return The SQL query to use for truncating a table
*/
@Since("2.3.0")
def getTruncateQuery(table: String): String = {
s"TRUNCATE TABLE $table"
}

/**
* Override connection specific properties to run before a select is made. This is in place to
* allow dialects that need special treatment to optimize behavior.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ private object PostgresDialect extends JdbcDialect {
s"SELECT 1 FROM $table LIMIT 1"
}

/**
* The SQL query used to truncate a table. For Postgres, the default behaviour is to
* also truncate any descendant tables. As this is a (possibly unwanted) side-effect,
* the Postgres dialect adds 'ONLY' to truncate only the table in question
* @param table The name of the table.
* @return The SQL query to use for truncating a table
*/
override def getTruncateQuery(table: String): String = {
s"TRUNCATE TABLE ONLY $table"
}

override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
super.beforeFetch(connection, properties)

Expand All @@ -97,8 +108,7 @@ private object PostgresDialect extends JdbcDialect {
if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
connection.setAutoCommit(false)
}

}

override def isCascadingTruncateTable(): Option[Boolean] = Some(true)
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
}
16 changes: 16 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,22 @@ class JDBCSuite extends SparkFunSuite
assert(derby.getTableExistsQuery(table) == defaultQuery)
}

test("truncate table query by jdbc dialect") {
val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
val h2 = JdbcDialects.get(url)
val derby = JdbcDialects.get("jdbc:derby:db")
val table = "weblogs"
val defaultQuery = s"TRUNCATE TABLE $table"
val postgresQuery = s"TRUNCATE TABLE ONLY $table"
assert(MySQL.getTruncateQuery(table) == defaultQuery)
assert(Postgres.getTruncateQuery(table) == postgresQuery)
assert(db2.getTruncateQuery(table) == defaultQuery)
assert(h2.getTruncateQuery(table) == defaultQuery)
assert(derby.getTruncateQuery(table) == defaultQuery)
}

test("Test DataFrame.where for Date and Timestamp") {
// Regression test for bug SPARK-11788
val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");
Expand Down