diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 6d465469f413..1df9d1526671 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -110,10 +110,11 @@ statement
         RENAME COLUMN from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn
     | ALTER TABLE multipartIdentifier
-        DROP (COLUMN | COLUMNS)
+        DROP (COLUMN | COLUMNS) (IF EXISTS)?
         LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN             #dropTableColumns
     | ALTER TABLE multipartIdentifier
-        DROP (COLUMN | COLUMNS) columns=multipartIdentifierList            #dropTableColumns
+        DROP (COLUMN | COLUMNS) (IF EXISTS)?
+        columns=multipartIdentifierList                                    #dropTableColumns
     | ALTER (TABLE | VIEW) from=multipartIdentifier
         RENAME TO to=multipartIdentifier                                   #renameTable
     | ALTER (TABLE | VIEW) multipartIdentifier
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
index 3ed185a82452..c63d2d458619 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java
@@ -224,10 +224,11 @@ static TableChange updateColumnPosition(String[] fieldNames, ColumnPosition newP
    * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
    *
    * @param fieldNames field names of the column to delete
+   * @param ifExists silence the error if column doesn't exist during drop
    * @return a TableChange for the delete
    */
-  static TableChange deleteColumn(String[] fieldNames) {
-    return new DeleteColumn(fieldNames);
+  static TableChange deleteColumn(String[] fieldNames, Boolean ifExists) {
+    return new DeleteColumn(fieldNames, ifExists);
   }

   /**
@@ -651,9 +652,11 @@ public int hashCode() {
    */
   final class DeleteColumn implements ColumnChange {
     private final String[] fieldNames;
+    private final Boolean ifExists;

-    private DeleteColumn(String[] fieldNames) {
+    private DeleteColumn(String[] fieldNames, Boolean ifExists) {
       this.fieldNames = fieldNames;
+      this.ifExists = ifExists;
     }

     @Override
@@ -661,12 +664,14 @@ public String[] fieldNames() {
       return fieldNames;
     }

+    public Boolean ifExists() { return ifExists; }
+
     @Override
     public boolean equals(Object o) {
       if (this == o) return true;
       if (o == null || getClass() != o.getClass()) return false;
       DeleteColumn that = (DeleteColumn) o;
-      return Arrays.equals(fieldNames, that.fieldNames);
+      return Arrays.equals(fieldNames, that.fieldNames) && that.ifExists() == this.ifExists();
     }

     @Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d00818ba1ea9..065cd9cde4dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3668,6 +3668,12 @@ class Analyzer(override val catalogManager: CatalogManager)
           case other => other
         })

+      case a: DropColumns if a.table.resolved && hasUnresolvedFieldName(a) && a.ifExists =>
+        // for DropColumn with IF EXISTS clause, we should resolve and ignore missing column errors
+        val table = a.table.asInstanceOf[ResolvedTable]
+        val columnsToDrop = a.columnsToDrop
+        a.copy(columnsToDrop = columnsToDrop.flatMap(c => resolveFieldNamesOpt(table, c.name, c)))
+
       case a: AlterTableCommand if a.table.resolved && hasUnresolvedFieldName(a) =>
         val table = a.table.asInstanceOf[ResolvedTable]
         a.transformExpressions {
@@ -3757,11 +3763,19 @@ class Analyzer(override val catalogManager: CatalogManager)
         table: ResolvedTable,
         fieldName: Seq[String],
         context: Expression): ResolvedFieldName = {
+      resolveFieldNamesOpt(table, fieldName, context)
+        .getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin))
+    }
+
+    private def resolveFieldNamesOpt(
+        table: ResolvedTable,
+        fieldName: Seq[String],
+        context: Expression): Option[ResolvedFieldName] = {
       table.schema.findNestedField(
         fieldName, includeCollections = true, conf.resolver, context.origin
       ).map {
         case (path, field) => ResolvedFieldName(path, field)
-      }.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin))
+      }
     }

     private def hasUnresolvedFieldName(a: AlterTableCommand): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4ffb4f1cfe1e..75987e78ea78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -3930,12 +3930,14 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
    */
   override def visitDropTableColumns(
       ctx: DropTableColumnsContext): LogicalPlan = withOrigin(ctx) {
+    val ifExists = ctx.EXISTS() != null
     val columnsToDrop = ctx.columns.multipartIdentifier.asScala.map(typedVisit[Seq[String]])
     DropColumns(
       createUnresolvedTable(
         ctx.multipartIdentifier,
         "ALTER TABLE ... DROP COLUMNS"),
-      columnsToDrop.map(UnresolvedFieldName(_)).toSeq)
+      columnsToDrop.map(UnresolvedFieldName(_)).toSeq,
+      ifExists)
   }

   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
index 302a810485c9..8cc93c2dd099 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
@@ -143,7 +143,8 @@ case class ReplaceColumns(
     // REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
     require(table.resolved)
     val deleteChanges = table.schema.fieldNames.map { name =>
-      TableChange.deleteColumn(Array(name))
+      // REPLACE COLUMN should require column to exist
+      TableChange.deleteColumn(Array(name), ifExists = false)
     }
     val addChanges = columnsToAdd.map { col =>
       assert(col.path.isEmpty)
@@ -167,11 +168,12 @@ case class ReplaceColumns(
  */
 case class DropColumns(
     table: LogicalPlan,
-    columnsToDrop: Seq[FieldName]) extends AlterTableCommand {
+    columnsToDrop: Seq[FieldName],
+    ifExists: Boolean) extends AlterTableCommand {

   override def changes: Seq[TableChange] = {
     columnsToDrop.map { col =>
       require(col.resolved, "FieldName should be resolved before it's converted to TableChange.")
-      TableChange.deleteColumn(col.name.toArray)
+      TableChange.deleteColumn(col.name.toArray, ifExists)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 4092674046ec..2fc13510c54e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -191,7 +191,7 @@ private[sql] object CatalogV2Util {
       }

       case delete: DeleteColumn =>
-        replace(schema, delete.fieldNames, _ => None)
+        replace(schema, delete.fieldNames, _ => None, delete.ifExists)

       case _ =>
         // ignore non-schema changes
@@ -222,17 +222,28 @@
   private def replace(
       struct: StructType,
       fieldNames: Seq[String],
-      update: StructField => Option[StructField]): StructType = {
+      update: StructField => Option[StructField],
+      ifExists: Boolean = false): StructType = {
+
+    val posOpt = struct.getFieldIndex(fieldNames.head)
+    if (posOpt.isEmpty) {
+      if (ifExists) {
+        // We couldn't find the column to replace, but with IF EXISTS, we will silence the error
+        // Currently only DROP COLUMN may pass down the IF EXISTS parameter
+        return struct
+      } else {
+        throw new IllegalArgumentException(s"Cannot find field: ${fieldNames.head}")
+      }
+    }

-    val pos = struct.getFieldIndex(fieldNames.head)
-      .getOrElse(throw new IllegalArgumentException(s"Cannot find field: ${fieldNames.head}"))
+    val pos = posOpt.get
     val field = struct.fields(pos)
     val replacement: Option[StructField] = (fieldNames.tail, field.dataType) match {
       case (Seq(), _) =>
         update(field)

       case (names, struct: StructType) =>
-        val updatedType: StructType = replace(struct, names, update)
+        val updatedType: StructType = replace(struct, names, update, ifExists)
         Some(StructField(field.name, updatedType, field.nullable, field.metadata))

       case (Seq("key"), map @ MapType(keyType, _, _)) =>
@@ -241,7 +252,7 @@
         Some(field.copy(dataType = map.copy(keyType = updated.dataType)))

       case (Seq("key", names @ _*), map @ MapType(keyStruct: StructType, _, _)) =>
-        Some(field.copy(dataType = map.copy(keyType = replace(keyStruct, names, update))))
+        Some(field.copy(dataType = map.copy(keyType = replace(keyStruct, names, update, ifExists))))

       case (Seq("value"), map @ MapType(_, mapValueType, isNullable)) =>
         val updated = update(StructField("value", mapValueType, nullable = isNullable))
@@ -251,7 +262,8 @@
             valueContainsNull = updated.nullable)))

       case (Seq("value", names @ _*), map @ MapType(_, valueStruct: StructType, _)) =>
-        Some(field.copy(dataType = map.copy(valueType = replace(valueStruct, names, update))))
+        Some(field.copy(dataType = map.copy(valueType =
+          replace(valueStruct, names, update, ifExists))))

       case (Seq("element"), array @ ArrayType(elementType, isNullable)) =>
         val updated = update(StructField("element", elementType, nullable = isNullable))
@@ -261,11 +273,15 @@
             containsNull = updated.nullable)))

       case (Seq("element", names @ _*), array @ ArrayType(elementStruct: StructType, _)) =>
-        Some(field.copy(dataType = array.copy(elementType = replace(elementStruct, names, update))))
+        Some(field.copy(dataType = array.copy(elementType =
+          replace(elementStruct, names, update, ifExists))))

       case (names, dataType) =>
-        throw new IllegalArgumentException(
-          s"Cannot find field: ${names.head} in ${dataType.simpleString}")
+        if (!ifExists) {
+          throw new IllegalArgumentException(
+            s"Cannot find field: ${names.head} in ${dataType.simpleString}")
+        }
+        None
     }

     val newFields = struct.fields.zipWithIndex.flatMap {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 8f280727a781..ade8c61f79fe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -1022,7 +1022,15 @@ class DDLParserSuite extends AnalysisTest {
       parsePlan("ALTER TABLE table_name DROP COLUMN a.b.c"),
       DropColumns(
         UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
-        Seq(UnresolvedFieldName(Seq("a", "b", "c")))))
+        Seq(UnresolvedFieldName(Seq("a", "b", "c"))),
+        ifExists = false))
+
+    comparePlans(
+      parsePlan("ALTER TABLE table_name DROP COLUMN IF EXISTS a.b.c"),
+      DropColumns(
+        UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
+        Seq(UnresolvedFieldName(Seq("a", "b", "c"))),
+        ifExists = true))
   }

   test("alter table: drop multiple columns") {
@@ -1034,7 +1042,20 @@ class DDLParserSuite extends AnalysisTest {
         UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
         Seq(UnresolvedFieldName(Seq("x")),
           UnresolvedFieldName(Seq("y")),
-          UnresolvedFieldName(Seq("a", "b", "c")))))
+          UnresolvedFieldName(Seq("a", "b", "c"))),
+        ifExists = false))
+    }
+
+    val sqlIfExists = "ALTER TABLE table_name DROP COLUMN IF EXISTS x, y, a.b.c"
+    Seq(sqlIfExists, sqlIfExists.replace("COLUMN", "COLUMNS")).foreach { drop =>
+      comparePlans(
+        parsePlan(drop),
+        DropColumns(
+          UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
+          Seq(UnresolvedFieldName(Seq("x")),
+            UnresolvedFieldName(Seq("y")),
+            UnresolvedFieldName(Seq("a", "b", "c"))),
+          ifExists = true))
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
index d00bc31e07f1..54aad8b63ad5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
@@ -550,7 +550,7 @@ class CatalogSuite extends SparkFunSuite {
     assert(table.schema == schema)

     val updated = catalog.alterTable(testIdent,
-      TableChange.deleteColumn(Array("id")))
+      TableChange.deleteColumn(Array("id"), false))

     val expectedSchema = new StructType().add("data", StringType)
     assert(updated.schema == expectedSchema)
@@ -567,7 +567,7 @@
     assert(table.schema == tableSchema)

     val updated = catalog.alterTable(testIdent,
-      TableChange.deleteColumn(Array("point", "y")))
+      TableChange.deleteColumn(Array("point", "y"), false))

     val newPointStruct = new StructType().add("x", DoubleType)
     val expectedSchema = schema.add("point", newPointStruct)
@@ -583,11 +583,15 @@
     assert(table.schema == schema)

     val exc = intercept[IllegalArgumentException] {
-      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col")))
+      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false))
     }

     assert(exc.getMessage.contains("missing_col"))
     assert(exc.getMessage.contains("Cannot find"))
+
+    // with if exists it should pass
+    catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), true))
+    assert(table.schema == schema)
   }

   test("alterTable: delete missing nested column fails") {
@@ -601,11 +605,15 @@
     assert(table.schema == tableSchema)

     val exc = intercept[IllegalArgumentException] {
-      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z")))
+      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false))
     }

     assert(exc.getMessage.contains("z"))
     assert(exc.getMessage.contains("Cannot find"))
+
+    // with if exists it should pass
+    catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), true))
+    assert(table.schema == tableSchema)
   }

   test("alterTable: table does not exist") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 5722bedfb53e..252b68825194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -93,7 +93,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
     case RenameColumn(ResolvedV1TableIdentifier(_), _, _) =>
       throw QueryCompilationErrors.operationOnlySupportedWithV2TableError("RENAME COLUMN")

-    case DropColumns(ResolvedV1TableIdentifier(_), _) =>
+    case DropColumns(ResolvedV1TableIdentifier(_), _, _) =>
       throw QueryCompilationErrors.operationOnlySupportedWithV2TableError("DROP COLUMN")

     case SetTableProperties(ResolvedV1TableIdentifier(ident), props) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
index 1b0898fbc12f..19f3f86c9411 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
@@ -1070,7 +1070,7 @@ trait AlterTableTests extends SharedSparkSession {
     }
   }

-  test("AlterTable: drop column must exist") {
+  test("AlterTable: drop column must exist if required") {
     val t = s"${catalogAndNamespace}table_name"
     withTable(t) {
       sql(s"CREATE TABLE $t (id int) USING $v2Format")
@@ -1080,10 +1080,15 @@
       }

       assert(exc.getMessage.contains("Missing field data"))
+
+      // with if exists it should pass
+      sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS data")
+      val table = getTableMetadata(fullTableName(t))
+      assert(table.schema == new StructType().add("id", IntegerType))
     }
   }

-  test("AlterTable: nested drop column must exist") {
+  test("AlterTable: nested drop column must exist if required") {
     val t = s"${catalogAndNamespace}table_name"
     withTable(t) {
       sql(s"CREATE TABLE $t (id int) USING $v2Format")
@@ -1093,6 +1098,27 @@
       }

       assert(exc.getMessage.contains("Missing field point.x"))
+
+      // with if exists it should pass
+      sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS point.x")
+      val table = getTableMetadata(fullTableName(t))
+      assert(table.schema == new StructType().add("id", IntegerType))
+
+    }
+  }
+
+  test("AlterTable: drop mixed existing/non-existing columns using IF EXISTS") {
+    val t = s"${catalogAndNamespace}table_name"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id int, name string, points array<struct<x: double, y: double>>) " +
+        s"USING $v2Format")
+
+      // with if exists it should pass
+      sql(s"ALTER TABLE $t DROP COLUMNS IF EXISTS " +
+        s"names, name, points.element.z, id, points.element.x")
+      val table = getTableMetadata(fullTableName(t))
+      assert(table.schema == new StructType()
+        .add("points", ArrayType(StructType(Seq(StructField("y", DoubleType))))))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
index e1ea31d973ea..4aedac2b4afa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala
@@ -286,10 +286,21 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
   test("AlterTable: drop column resolution") {
     Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref =>
-      alterTableTest(
-        DropColumns(table, Seq(UnresolvedFieldName(ref))),
-        Seq("Missing field " + ref.quoted)
-      )
+      Seq(true, false).foreach { ifExists =>
+        val expectedErrors = if (ifExists) {
+          Seq.empty[String]
+        } else {
+          Seq("Missing field " + ref.quoted)
+        }
+        val alter = DropColumns(table, Seq(UnresolvedFieldName(ref)), ifExists)
+        if (ifExists) {
+          // using IF EXISTS will silence all errors for missing columns
+          assertAnalysisSuccess(alter, caseSensitive = true)
+          assertAnalysisSuccess(alter, caseSensitive = false)
+        } else {
+          alterTableTest(alter, expectedErrors, expectErrorOnCaseSensitive = true)
+        }
+      }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
index bae793bb0121..af0eafbc805e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
@@ -610,7 +610,7 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite {
     assert(table.schema == schema)

     val updated = catalog.alterTable(testIdent,
-      TableChange.deleteColumn(Array("id")))
+      TableChange.deleteColumn(Array("id"), false))

     val expectedSchema = new StructType().add("data", StringType)
     assert(updated.schema == expectedSchema)
@@ -627,7 +627,7 @@
     assert(table.schema == tableSchema)

     val updated = catalog.alterTable(testIdent,
-      TableChange.deleteColumn(Array("point", "y")))
+      TableChange.deleteColumn(Array("point", "y"), false))

     val newPointStruct = new StructType().add("x", DoubleType)
     val expectedSchema = schema.add("point", newPointStruct)
@@ -643,11 +643,15 @@
     assert(table.schema == schema)

     val exc = intercept[IllegalArgumentException] {
-      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col")))
+      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false))
     }

     assert(exc.getMessage.contains("missing_col"))
     assert(exc.getMessage.contains("Cannot find"))
+
+    // with if exists it should pass
+    catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), true))
+    assert(table.schema == schema)
   }

   test("alterTable: delete missing nested column fails") {
@@ -661,11 +665,15 @@
     assert(table.schema == tableSchema)

     val exc = intercept[IllegalArgumentException] {
-      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z")))
+      catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false))
     }

     assert(exc.getMessage.contains("z"))
     assert(exc.getMessage.contains("Cannot find"))
+
+    // with if exists it should pass
+    catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), true))
+    assert(table.schema == tableSchema)
   }

   test("alterTable: table does not exist") {
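For reviewers, a quick end-to-end sketch of the behavior this patch enables at the SQL surface. This is illustrative only: the catalog name `testcat`, namespace `ns`, and provider `foo` are hypothetical stand-ins for any DataSource V2 catalog and format (the suites above get theirs from `catalogAndNamespace` and `v2Format`).

```scala
// Assumes a SparkSession `spark` with a V2 catalog registered as `testcat`.
spark.sql(
  "CREATE TABLE testcat.ns.t (id INT, name STRING, point STRUCT<x: DOUBLE, y: DOUBLE>) USING foo")

// Without IF EXISTS, dropping a missing column still fails analysis:
//   spark.sql("ALTER TABLE testcat.ns.t DROP COLUMN data")
//   => AnalysisException: Missing field data

// With IF EXISTS, the missing columns (`data`, `point.z`) are skipped silently
// while the existing one (`name`) is dropped.
spark.sql("ALTER TABLE testcat.ns.t DROP COLUMNS IF EXISTS data, name, point.z")
// Remaining schema: id INT, point STRUCT<x: DOUBLE, y: DOUBLE>
```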
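On the connector side, catalogs that compute their new schema through `CatalogV2Util.applySchemaChanges` pick up the IF EXISTS semantics for free via the updated `replace`. A catalog that applies `TableChange`s by hand now has to consult the new flag itself. Below is a minimal sketch for top-level columns only; `applyDeletes` is a hypothetical helper for illustration, not an API added by this patch.

```scala
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange.DeleteColumn
import org.apache.spark.sql.types.StructType

def applyDeletes(schema: StructType, changes: Seq[TableChange]): StructType =
  changes.foldLeft(schema) {
    case (cur, delete: DeleteColumn) =>
      // Only handles top-level fields; nested paths would recurse the way replace does.
      val name = delete.fieldNames.head
      cur.getFieldIndex(name) match {
        case Some(_) => StructType(cur.fields.filterNot(_.name == name))
        case None if delete.ifExists => cur // IF EXISTS: silently skip the missing column
        case None => throw new IllegalArgumentException(s"Cannot find field: $name")
      }
    case (cur, _) => cur // non-delete changes are out of scope for this sketch
  }
```

Note that `ReplaceColumns` deliberately passes `ifExists = false`, so REPLACE COLUMNS keeps failing fast when a column it expects to delete is missing.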