JdbcUtils.scala

@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.jdbc

-import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, SQLFeatureNotSupportedException}
import java.util.Locale

import scala.collection.JavaConverters._
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
@@ -94,13 +95,7 @@ object JdbcUtils extends Logging {
* Drops a table from the JDBC database.
*/
def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
-    val statement = conn.createStatement
-    try {
-      statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(s"DROP TABLE $table")
-    } finally {
-      statement.close()
-    }
+    executeStatement(conn, options, s"DROP TABLE $table")
}
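(dropTable, createTable, and renameTable now all delegate to a private executeStatement helper added later in this file, which applies the query timeout and closes the statement.)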

/**
@@ -184,7 +179,7 @@
}
}

-  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}"))
}
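(The change above drops the private modifier so that the new default JdbcDialect.alterTable, added in this PR, can reuse this dialect-aware type mapping.)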
@@ -882,13 +877,7 @@ object JdbcUtils extends Logging {
// table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
val sql = s"CREATE TABLE $tableName ($strSchema) $createTableOptions"
-    val statement = conn.createStatement
-    try {
-      statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(sql)
-    } finally {
-      statement.close()
-    }
+    executeStatement(conn, options, sql)
}

/**
@@ -900,10 +889,51 @@
newTable: String,
options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(options.url)
executeStatement(conn, options, dialect.renameTable(oldTable, newTable))
}

/**
* Alters a table in the JDBC database.
*/
def alterTable(
conn: Connection,
tableName: String,
changes: Seq[TableChange],
options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(options.url)
if (changes.length == 1) {
executeStatement(conn, options, dialect.alterTable(tableName, changes)(0))
} else {
val metadata = conn.getMetaData
if (!metadata.supportsTransactions) {
throw new SQLFeatureNotSupportedException("The target JDBC server does not support " +
"transactions and can only support ALTER TABLE with a single action.")
} else {
conn.setAutoCommit(false)
val statement = conn.createStatement
try {
statement.setQueryTimeout(options.queryTimeout)
for (sql <- dialect.alterTable(tableName, changes)) {
statement.executeUpdate(sql)
}
conn.commit()
} catch {
case e: Exception =>
if (conn != null) conn.rollback()
throw e
} finally {
statement.close()
conn.setAutoCommit(true)
[Review comment, Member] Should we record the auto-commit status and restore it afterwards?

[Reply from the author, Contributor] I guess we don't need to record the original auto-commit status. The default value of autocommit is true.
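A minimal sketch of the record-and-restore variant discussed above (illustrative only; the PR keeps the simpler setAutoCommit(true), relying on connections arriving with the default autocommit of true):

    val originalAutoCommit = conn.getAutoCommit  // record the current status
    conn.setAutoCommit(false)
    try {
      // run the generated ALTER TABLE statements, then conn.commit()
    } finally {
      conn.setAutoCommit(originalAutoCommit)  // restore the recorded status
    }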

}
}
}
}

private def executeStatement(conn: Connection, options: JDBCOptions, sql: String): Unit = {
val statement = conn.createStatement
try {
statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(dialect.renameTable(oldTable, newTable))
+      statement.executeUpdate(sql)
} finally {
statement.close()
}
JDBCTableCatalog.scala

@@ -129,11 +129,12 @@ class JDBCTableCatalog extends TableCatalog with Logging {
JDBCTable(ident, schema, writeOptions)
}

-  // TODO (SPARK-32402): Implement ALTER TABLE in JDBC Table Catalog
   override def alterTable(ident: Identifier, changes: TableChange*): Table = {
-    // scalastyle:off throwerror
-    throw new NotImplementedError()
-    // scalastyle:on throwerror
+    checkNamespace(ident.namespace())
+    withConnection { conn =>
+      JdbcUtils.alterTable(conn, getTableName(ident), changes, options)
+      loadTable(ident)
+    }
}
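As an illustrative aside (table name is hypothetical, mirroring the test suite below): a single ALTER TABLE statement that adds several columns reaches this method as several TableChange objects, which is exactly what exercises the multi-statement transactional path in JdbcUtils.alterTable.

    // Two added columns arrive as two TableChange objects:
    sql("ALTER TABLE h2.test.alt_table ADD COLUMNS (C1 INTEGER, C2 STRING)")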

private def checkNamespace(namespace: Array[String]): Unit = {
JdbcDialect.scala

@@ -17,11 +17,16 @@

package org.apache.spark.sql.jdbc

-import java.sql.{Connection, Date, Timestamp}
+import java.sql.{Connection, Date, SQLFeatureNotSupportedException, Timestamp}

+import scala.collection.mutable.ArrayBuilder

import org.apache.commons.lang3.StringUtils

import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.sql.connector.catalog.TableChange
+import org.apache.spark.sql.connector.catalog.TableChange._
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._

/**
@@ -184,15 +189,51 @@
/**
* Rename an existing table.
*
* TODO (SPARK-32382): Override this method in the dialects that don't support such syntax.
*
* @param oldTable The existing table.
* @param newTable New name of the table.
* @return The SQL statement to use for renaming the table.
*/
def renameTable(oldTable: String, newTable: String): String = {
s"ALTER TABLE $oldTable RENAME TO $newTable"
}
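For example, this default turns renameTable("test.people", "test.new_table") into the single statement ALTER TABLE test.people RENAME TO test.new_table (names borrowed from the test suite below).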

/**
* Alter an existing table.
* TODO (SPARK-32523): Override this method in the dialects that have different syntax.
[Review comment, Member] Because you will override this method in other places, not here. Remember to remove this later. :)
*
* @param tableName The name of the table to be altered.
* @param changes Changes to apply to the table.
* @return The SQL statements to use for altering the table.
*/
def alterTable(tableName: String, changes: Seq[TableChange]): Array[String] = {
val updateClause = ArrayBuilder.make[String]
for (change <- changes) {
change match {
case add: AddColumn if add.fieldNames.length == 1 =>
val dataType = JdbcUtils.getJdbcType(add.dataType(), this).databaseTypeDefinition
val name = add.fieldNames
updateClause += s"ALTER TABLE $tableName ADD COLUMN ${name(0)} $dataType"
case rename: RenameColumn if rename.fieldNames.length == 1 =>
val name = rename.fieldNames
updateClause += s"ALTER TABLE $tableName RENAME COLUMN ${name(0)} TO ${rename.newName}"
case delete: DeleteColumn if delete.fieldNames.length == 1 =>
val name = delete.fieldNames
updateClause += s"ALTER TABLE $tableName DROP COLUMN ${name(0)}"
case updateColumnType: UpdateColumnType if updateColumnType.fieldNames.length == 1 =>
val name = updateColumnType.fieldNames
val dataType = JdbcUtils.getJdbcType(updateColumnType.newDataType(), this)
.databaseTypeDefinition
updateClause += s"ALTER TABLE $tableName ALTER COLUMN ${name(0)} $dataType"
case updateNull: UpdateColumnNullability if updateNull.fieldNames.length == 1 =>
val name = updateNull.fieldNames
val nullable = if (updateNull.nullable()) "NULL" else "NOT NULL"
updateClause += s"ALTER TABLE $tableName ALTER COLUMN ${name(0)} SET $nullable"
case _ =>
throw new SQLFeatureNotSupportedException(s"Unsupported TableChange $change")
}
}
updateClause.result()
}
}
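For illustration, a hedged sketch of what this default implementation emits. Here dialect stands for any concrete JdbcDialect whose getJDBCType falls through to the common mappings in JdbcUtils, and the table and column names are hypothetical:

    import org.apache.spark.sql.connector.catalog.TableChange
    import org.apache.spark.sql.types.IntegerType

    val changes = Seq(
      TableChange.addColumn(Array("C3"), IntegerType),
      TableChange.renameColumn(Array("C1"), "C0"))
    dialect.alterTable("test.alt_table", changes)
    // => Array(
    //      "ALTER TABLE test.alt_table ADD COLUMN C3 INTEGER",
    //      "ALTER TABLE test.alt_table RENAME COLUMN C1 TO C0")

Per the TODO and review note above, dialects with diverging ALTER TABLE syntax are expected to override this default in follow-up work (SPARK-32523).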

/**
JDBCTableCatalogSuite.scala

@@ -22,7 +22,7 @@ import java.util.Properties
import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession {
@@ -106,4 +106,71 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession {
Seq(Row("test", "people"), Row("test", "new_table")))
}
}

test("alter table ... add column") {
withTable("h2.test.alt_table") {
sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
sql("ALTER TABLE h2.test.alt_table ADD COLUMNS (C1 INTEGER, C2 STRING)")
var t = spark.table("h2.test.alt_table")
var expectedSchema = new StructType()
.add("ID", IntegerType)
.add("C1", IntegerType)
.add("C2", StringType)
assert(t.schema === expectedSchema)
sql("ALTER TABLE h2.test.alt_table ADD COLUMNS (C3 DOUBLE)")
t = spark.table("h2.test.alt_table")
expectedSchema = expectedSchema.add("C3", DoubleType)
assert(t.schema === expectedSchema)
[Author note, Contributor] I think it's better to compare schemas here than to use DESCRIBE TABLE, as my previous revision did.
}
}

test("alter table ... rename column") {
withTable("h2.test.alt_table") {
sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
sql("ALTER TABLE h2.test.alt_table RENAME COLUMN ID TO C")
val t = spark.table("h2.test.alt_table")
val expectedSchema = new StructType().add("C", IntegerType)
assert(t.schema === expectedSchema)
}
}

test("alter table ... drop column") {
withTable("h2.test.alt_table") {
sql("CREATE TABLE h2.test.alt_table (C1 INTEGER, C2 INTEGER) USING _")
sql("ALTER TABLE h2.test.alt_table DROP COLUMN C1")
val t = spark.table("h2.test.alt_table")
val expectedSchema = new StructType().add("C2", IntegerType)
assert(t.schema === expectedSchema)
}
}

test("alter table ... update column type") {
withTable("h2.test.alt_table") {
sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
sql("ALTER TABLE h2.test.alt_table ALTER COLUMN id TYPE DOUBLE")
val t = spark.table("h2.test.alt_table")
val expectedSchema = new StructType().add("ID", DoubleType)
assert(t.schema === expectedSchema)
}
}

test("alter table ... update column nullability") {
withTable("h2.test.alt_table") {
sql("CREATE TABLE h2.test.alt_table (ID INTEGER NOT NULL) USING _")
sql("ALTER TABLE h2.test.alt_table ALTER COLUMN ID DROP NOT NULL")
val t = spark.table("h2.test.alt_table")
val expectedSchema = new StructType().add("ID", IntegerType, nullable = true)
assert(t.schema === expectedSchema)
}
}

test("alter table ... update column comment not supported") {
withTable("h2.test.alt_table") {
sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
val thrown = intercept[java.sql.SQLFeatureNotSupportedException] {
sql("ALTER TABLE h2.test.alt_table ALTER COLUMN ID COMMENT 'test'")
}
assert(thrown.getMessage.contains("Unsupported TableChange"))
}
}
}