diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 3c6649b26ecd2..178dc6176e7c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException, SQLFeatureNotSupportedException}
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
 import org.apache.spark.sql.types._
@@ -94,13 +95,7 @@ object JdbcUtils extends Logging {
    * Drops a table from the JDBC database.
    */
   def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
-    val statement = conn.createStatement
-    try {
-      statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(s"DROP TABLE $table")
-    } finally {
-      statement.close()
-    }
+    executeStatement(conn, options, s"DROP TABLE $table")
   }
 
   /**
@@ -184,7 +179,7 @@ object JdbcUtils extends Logging {
     }
   }
 
-  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+  def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
     dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
       throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}"))
   }
@@ -882,13 +877,7 @@ object JdbcUtils extends Logging {
     // table_options or partition_options.
     // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
     val sql = s"CREATE TABLE $tableName ($strSchema) $createTableOptions"
-    val statement = conn.createStatement
-    try {
-      statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(sql)
-    } finally {
-      statement.close()
-    }
+    executeStatement(conn, options, sql)
  }
 
   /**
@@ -900,10 +889,51 @@ object JdbcUtils extends Logging {
       newTable: String,
       options: JDBCOptions): Unit = {
     val dialect = JdbcDialects.get(options.url)
+    executeStatement(conn, options, dialect.renameTable(oldTable, newTable))
+  }
+
+  /**
+   * Alter an existing table in the JDBC database.
+   */
+  def alterTable(
+      conn: Connection,
+      tableName: String,
+      changes: Seq[TableChange],
+      options: JDBCOptions): Unit = {
+    val dialect = JdbcDialects.get(options.url)
+    if (changes.length == 1) {
+      executeStatement(conn, options, dialect.alterTable(tableName, changes)(0))
+    } else {
+      val metadata = conn.getMetaData
+      if (!metadata.supportsTransactions) {
+        throw new SQLFeatureNotSupportedException("The target JDBC server does not support " +
+          "transactions and can only support ALTER TABLE with a single action.")
+      } else {
+        conn.setAutoCommit(false)
+        val statement = conn.createStatement
+        try {
+          statement.setQueryTimeout(options.queryTimeout)
+          for (sql <- dialect.alterTable(tableName, changes)) {
+            statement.executeUpdate(sql)
+          }
+          conn.commit()
+        } catch {
+          case e: Exception =>
+            conn.rollback() // roll back all applied changes if any statement fails
+            throw e
+        } finally {
+          statement.close()
+          conn.setAutoCommit(true)
+        }
+      }
+    }
+  }
+
+  private def executeStatement(conn: Connection, options: JDBCOptions, sql: String): Unit = {
     val statement = conn.createStatement
     try {
       statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(dialect.renameTable(oldTable, newTable))
+      statement.executeUpdate(sql)
     } finally {
       statement.close()
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
index 5d64cf4ca896e..0138014a8e21e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
@@ -129,11 +129,12 @@ class JDBCTableCatalog extends TableCatalog with Logging {
     JDBCTable(ident, schema, writeOptions)
   }
 
-  // TODO (SPARK-32402): Implement ALTER TABLE in JDBC Table Catalog
   override def alterTable(ident: Identifier, changes: TableChange*): Table = {
-    // scalastyle:off throwerror
-    throw new NotImplementedError()
-    // scalastyle:on throwerror
+    checkNamespace(ident.namespace())
+    withConnection { conn =>
+      JdbcUtils.alterTable(conn, getTableName(ident), changes, options)
+      loadTable(ident)
+    }
   }
 
   private def checkNamespace(namespace: Array[String]): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index b0f9aba859d3a..cea5a20917532 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -17,11 +17,16 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.{Connection, Date, Timestamp}
+import java.sql.{Connection, Date, SQLFeatureNotSupportedException, Timestamp}
+
+import scala.collection.mutable.ArrayBuilder
 
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
+import org.apache.spark.sql.connector.catalog.TableChange
+import org.apache.spark.sql.connector.catalog.TableChange._
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -184,8 +189,6 @@ abstract class JdbcDialect extends Serializable {
   /**
    * Rename an existing table.
    *
-   * TODO (SPARK-32382): Override this method in the dialects that don't support such syntax.
-   *
    * @param oldTable The existing table.
   * @param newTable New name of the table.
   * @return The SQL statement to use for renaming the table.
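Note that the single-change path above never opens a transaction; only a multi-change ALTER is wrapped in commit/rollback, and servers without transaction support are rejected up front. A minimal sketch of driving the new `JdbcUtils.alterTable` entry point directly — the URL, table name, and column names are illustrative assumptions, and it presumes the H2 driver is on the classpath with a pre-existing `test.people` table:

```scala
// Sketch only: two changes trigger the transactional path; a single change
// would execute as one auto-committed statement.
import java.sql.DriverManager

import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.IntegerType

val url = "jdbc:h2:mem:testdb0" // assumed in-memory H2 instance
val options = new JDBCOptions(Map("url" -> url, "dbtable" -> "test.people"))
val conn = DriverManager.getConnection(url)
try {
  JdbcUtils.alterTable(conn, "test.people", Seq(
    TableChange.addColumn(Array("age"), IntegerType),      // hypothetical column
    TableChange.renameColumn(Array("name"), "full_name")), // hypothetical column
    options)
} finally {
  conn.close()
}
```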
@@ -193,6 +196,44 @@ abstract class JdbcDialect extends Serializable {
   def renameTable(oldTable: String, newTable: String): String = {
     s"ALTER TABLE $oldTable RENAME TO $newTable"
   }
+
+  /**
+   * Alter an existing table.
+   * TODO (SPARK-32523): Override this method in the dialects that have different syntax.
+   *
+   * @param tableName The name of the table to be altered.
+   * @param changes Changes to apply to the table.
+   * @return The SQL statements to use for altering the table.
+   */
+  def alterTable(tableName: String, changes: Seq[TableChange]): Array[String] = {
+    val updateClause = ArrayBuilder.make[String]
+    for (change <- changes) {
+      change match {
+        case add: AddColumn if add.fieldNames.length == 1 =>
+          val dataType = JdbcUtils.getJdbcType(add.dataType(), this).databaseTypeDefinition
+          val name = add.fieldNames
+          updateClause += s"ALTER TABLE $tableName ADD COLUMN ${name(0)} $dataType"
+        case rename: RenameColumn if rename.fieldNames.length == 1 =>
+          val name = rename.fieldNames
+          updateClause += s"ALTER TABLE $tableName RENAME COLUMN ${name(0)} TO ${rename.newName}"
+        case delete: DeleteColumn if delete.fieldNames.length == 1 =>
+          val name = delete.fieldNames
+          updateClause += s"ALTER TABLE $tableName DROP COLUMN ${name(0)}"
+        case updateColumnType: UpdateColumnType if updateColumnType.fieldNames.length == 1 =>
+          val name = updateColumnType.fieldNames
+          val dataType = JdbcUtils.getJdbcType(updateColumnType.newDataType(), this)
+            .databaseTypeDefinition
+          updateClause += s"ALTER TABLE $tableName ALTER COLUMN ${name(0)} $dataType"
+        case updateNull: UpdateColumnNullability if updateNull.fieldNames.length == 1 =>
+          val name = updateNull.fieldNames
+          val nullable = if (updateNull.nullable()) "NULL" else "NOT NULL"
+          updateClause += s"ALTER TABLE $tableName ALTER COLUMN ${name(0)} SET $nullable"
+        case _ =>
+          throw new SQLFeatureNotSupportedException(s"Unsupported TableChange $change")
+      }
+    }
+    updateClause.result()
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala
index 0eb96b7813e6e..b308934ba03c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalogSuite.scala
@@ -22,7 +22,7 @@ import java.util.Properties
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession {
@@ -106,4 +106,71 @@ class JDBCTableCatalogSuite extends QueryTest with SharedSparkSession {
       Seq(Row("test", "people"), Row("test", "new_table")))
     }
   }
+
+  test("alter table ... add column") {
+    withTable("h2.test.alt_table") {
+      sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
+      sql("ALTER TABLE h2.test.alt_table ADD COLUMNS (C1 INTEGER, C2 STRING)")
+      var t = spark.table("h2.test.alt_table")
+      var expectedSchema = new StructType()
+        .add("ID", IntegerType)
+        .add("C1", IntegerType)
+        .add("C2", StringType)
+      assert(t.schema === expectedSchema)
+      sql("ALTER TABLE h2.test.alt_table ADD COLUMNS (C3 DOUBLE)")
+      t = spark.table("h2.test.alt_table")
+      expectedSchema = expectedSchema.add("C3", DoubleType)
+      assert(t.schema === expectedSchema)
+    }
+  }
+
+  test("alter table ... rename column") {
+    withTable("h2.test.alt_table") {
+      sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
+      sql("ALTER TABLE h2.test.alt_table RENAME COLUMN ID TO C")
+      val t = spark.table("h2.test.alt_table")
+      val expectedSchema = new StructType().add("C", IntegerType)
+      assert(t.schema === expectedSchema)
+    }
+  }
+
+  test("alter table ... drop column") {
+    withTable("h2.test.alt_table") {
+      sql("CREATE TABLE h2.test.alt_table (C1 INTEGER, C2 INTEGER) USING _")
+      sql("ALTER TABLE h2.test.alt_table DROP COLUMN C1")
+      val t = spark.table("h2.test.alt_table")
+      val expectedSchema = new StructType().add("C2", IntegerType)
+      assert(t.schema === expectedSchema)
+    }
+  }
+
+  test("alter table ... update column type") {
+    withTable("h2.test.alt_table") {
+      sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
+      sql("ALTER TABLE h2.test.alt_table ALTER COLUMN id TYPE DOUBLE")
+      val t = spark.table("h2.test.alt_table")
+      val expectedSchema = new StructType().add("ID", DoubleType)
+      assert(t.schema === expectedSchema)
+    }
+  }
+
+  test("alter table ... update column nullability") {
+    withTable("h2.test.alt_table") {
+      sql("CREATE TABLE h2.test.alt_table (ID INTEGER NOT NULL) USING _")
+      sql("ALTER TABLE h2.test.alt_table ALTER COLUMN ID DROP NOT NULL")
+      val t = spark.table("h2.test.alt_table")
+      val expectedSchema = new StructType().add("ID", IntegerType, nullable = true)
+      assert(t.schema === expectedSchema)
+    }
+  }
+
+  test("alter table ... update column comment not supported") {
+    withTable("h2.test.alt_table") {
+      sql("CREATE TABLE h2.test.alt_table (ID INTEGER) USING _")
+      val thrown = intercept[java.sql.SQLFeatureNotSupportedException] {
+        sql("ALTER TABLE h2.test.alt_table ALTER COLUMN ID COMMENT 'test'")
+      }
+      assert(thrown.getMessage.contains("Unsupported TableChange"))
+    }
+  }
 }
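Since the default `JdbcDialect.alterTable` added above is a pure SQL generator, its output can be inspected without a live connection. A small sketch, assuming no registered dialect overrides the default for this URL; the table and column names are illustrative:

```scala
// Sketch: shows the SQL the new default JdbcDialect.alterTable emits.
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb0") // assumed URL
val statements = dialect.alterTable("test.alt_table", Seq(
  TableChange.renameColumn(Array("ID"), "C"),
  TableChange.updateColumnNullability(Array("C"), true)))
// Per the default implementation above, this yields:
//   ALTER TABLE test.alt_table RENAME COLUMN ID TO C
//   ALTER TABLE test.alt_table ALTER COLUMN C SET NULL
statements.foreach(println)
```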