@@ -110,10 +110,11 @@ statement
RENAME COLUMN
from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn
| ALTER TABLE multipartIdentifier
DROP (COLUMN | COLUMNS)
DROP (COLUMN | COLUMNS) (IF EXISTS)?
LEFT_PAREN columns=multipartIdentifierList RIGHT_PAREN #dropTableColumns
| ALTER TABLE multipartIdentifier
DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns
DROP (COLUMN | COLUMNS) (IF EXISTS)?
columns=multipartIdentifierList #dropTableColumns
| ALTER (TABLE | VIEW) from=multipartIdentifier
RENAME TO to=multipartIdentifier #renameTable
| ALTER (TABLE | VIEW) multipartIdentifier
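For reference, the new (IF EXISTS)? token is accepted in both the parenthesized and the bare column-list forms of the rule above. A minimal usage sketch, assuming a v2 table `t` with columns (id INT, data STRING) — the table and column names are hypothetical, and the behavior shown mirrors the tests further down in this diff:

spark.sql("ALTER TABLE t DROP COLUMN IF EXISTS data")            // drops data
spark.sql("ALTER TABLE t DROP COLUMNS IF EXISTS (missing, id)")  // parenthesized form; missing is skipped, id is dropped
spark.sql("ALTER TABLE t DROP COLUMN missing")                   // without IF EXISTS this still fails with a missing-field error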
@@ -224,10 +224,11 @@ static TableChange updateColumnPosition(String[] fieldNames, ColumnPosition newP
* If the field does not exist, the change will result in an {@link IllegalArgumentException}.
*
* @param fieldNames field names of the column to delete
* @param ifExists whether to silence the error if the column does not exist during the drop
* @return a TableChange for the delete
*/
static TableChange deleteColumn(String[] fieldNames) {
return new DeleteColumn(fieldNames);
static TableChange deleteColumn(String[] fieldNames, Boolean ifExists) {
return new DeleteColumn(fieldNames, ifExists);
}
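As a usage sketch of the updated factory — the catalog and identifier below are hypothetical, and the semantics match the CatalogSuite tests later in this diff:

// Scala, assuming a TableCatalog `catalog` and an Identifier `ident` for an existing table
import org.apache.spark.sql.connector.catalog.TableChange

catalog.alterTable(ident, TableChange.deleteColumn(Array("data"), false))  // strict: IllegalArgumentException if data is missing
catalog.alterTable(ident, TableChange.deleteColumn(Array("data"), true))   // lenient: a missing column is silently ignored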

/**
@@ -651,22 +652,26 @@ public int hashCode() {
*/
final class DeleteColumn implements ColumnChange {
private final String[] fieldNames;
private final Boolean ifExists;

private DeleteColumn(String[] fieldNames) {
private DeleteColumn(String[] fieldNames, Boolean ifExists) {
this.fieldNames = fieldNames;
this.ifExists = ifExists;
}

@Override
public String[] fieldNames() {
return fieldNames;
}

public Boolean ifExists() { return ifExists; }

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DeleteColumn that = (DeleteColumn) o;
return Arrays.equals(fieldNames, that.fieldNames);
return Arrays.equals(fieldNames, that.fieldNames) && that.ifExists() == this.ifExists();
}

@Override
@@ -3668,6 +3668,12 @@ class Analyzer(override val catalogManager: CatalogManager)
case other => other
})

case a: DropColumns if a.table.resolved && hasUnresolvedFieldName(a) && a.ifExists =>
// For DropColumns with the IF EXISTS clause, resolve the columns we can and silently skip missing ones
val table = a.table.asInstanceOf[ResolvedTable]
val columnsToDrop = a.columnsToDrop
a.copy(columnsToDrop = columnsToDrop.flatMap(c => resolveFieldNamesOpt(table, c.name, c)))

case a: AlterTableCommand if a.table.resolved && hasUnresolvedFieldName(a) =>
val table = a.table.asInstanceOf[ResolvedTable]
a.transformExpressions {
@@ -3757,11 +3763,19 @@ class Analyzer(override val catalogManager: CatalogManager)
table: ResolvedTable,
fieldName: Seq[String],
context: Expression): ResolvedFieldName = {
resolveFieldNamesOpt(table, fieldName, context)
.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin))
}

private def resolveFieldNamesOpt(
table: ResolvedTable,
fieldName: Seq[String],
context: Expression): Option[ResolvedFieldName] = {
table.schema.findNestedField(
fieldName, includeCollections = true, conf.resolver, context.origin
).map {
case (path, field) => ResolvedFieldName(path, field)
}.getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin))
}
}
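The new DropColumns case above relies on flatMap over the Option returned by resolveFieldNamesOpt, so columns that fail to resolve simply vanish from the plan when IF EXISTS is set. A simplified, self-contained illustration of that pattern (the resolvable names are made up; the real code resolves against the table schema):

// toy stand-in for resolveFieldNamesOpt: only "id" and "data" resolve
def resolveOpt(name: Seq[String]): Option[Seq[String]] =
  if (Set(Seq("id"), Seq("data")).contains(name)) Some(name) else None

val requested = Seq(Seq("id"), Seq("missing"), Seq("data"))
val resolved = requested.flatMap(resolveOpt)  // Seq(Seq("id"), Seq("data")) -- "missing" is dropped without an error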

private def hasUnresolvedFieldName(a: AlterTableCommand): Boolean = {
@@ -3930,12 +3930,14 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
*/
override def visitDropTableColumns(
ctx: DropTableColumnsContext): LogicalPlan = withOrigin(ctx) {
val ifExists = ctx.EXISTS() != null
val columnsToDrop = ctx.columns.multipartIdentifier.asScala.map(typedVisit[Seq[String]])
DropColumns(
createUnresolvedTable(
ctx.multipartIdentifier,
"ALTER TABLE ... DROP COLUMNS"),
columnsToDrop.map(UnresolvedFieldName(_)).toSeq)
columnsToDrop.map(UnresolvedFieldName(_)).toSeq,
ifExists)
}

/**
@@ -143,7 +143,8 @@ case class ReplaceColumns(
// REPLACE COLUMNS deletes all the existing columns and adds new columns specified.
require(table.resolved)
val deleteChanges = table.schema.fieldNames.map { name =>
TableChange.deleteColumn(Array(name))
// REPLACE COLUMNS requires the existing columns to be present, so ifExists stays false
TableChange.deleteColumn(Array(name), ifExists = false)
}
val addChanges = columnsToAdd.map { col =>
assert(col.path.isEmpty)
@@ -167,11 +168,12 @@
*/
case class DropColumns(
table: LogicalPlan,
columnsToDrop: Seq[FieldName]) extends AlterTableCommand {
columnsToDrop: Seq[FieldName],
ifExists: Boolean) extends AlterTableCommand {
override def changes: Seq[TableChange] = {
columnsToDrop.map { col =>
require(col.resolved, "FieldName should be resolved before it's converted to TableChange.")
TableChange.deleteColumn(col.name.toArray)
TableChange.deleteColumn(col.name.toArray, ifExists)
}
}

@@ -191,7 +191,7 @@ private[sql] object CatalogV2Util {
}

case delete: DeleteColumn =>
replace(schema, delete.fieldNames, _ => None)
replace(schema, delete.fieldNames, _ => None, delete.ifExists)

case _ =>
// ignore non-schema changes
@@ -222,17 +222,28 @@ private[sql] object CatalogV2Util {
private def replace(
struct: StructType,
fieldNames: Seq[String],
update: StructField => Option[StructField]): StructType = {
update: StructField => Option[StructField],
ifExists: Boolean = false): StructType = {

val posOpt = struct.getFieldIndex(fieldNames.head)
if (posOpt.isEmpty) {
if (ifExists) {
// We couldn't find the column to replace, but with IF EXISTS, we will silence the error
// Currently only DROP COLUMN may pass down the IF EXISTS parameter
return struct
} else {
throw new IllegalArgumentException(s"Cannot find field: ${fieldNames.head}")
}
}

val pos = struct.getFieldIndex(fieldNames.head)
.getOrElse(throw new IllegalArgumentException(s"Cannot find field: ${fieldNames.head}"))
val pos = posOpt.get
val field = struct.fields(pos)
val replacement: Option[StructField] = (fieldNames.tail, field.dataType) match {
case (Seq(), _) =>
update(field)

case (names, struct: StructType) =>
val updatedType: StructType = replace(struct, names, update)
val updatedType: StructType = replace(struct, names, update, ifExists)
Some(StructField(field.name, updatedType, field.nullable, field.metadata))

case (Seq("key"), map @ MapType(keyType, _, _)) =>
@@ -241,7 +252,7 @@
Some(field.copy(dataType = map.copy(keyType = updated.dataType)))

case (Seq("key", names @ _*), map @ MapType(keyStruct: StructType, _, _)) =>
Some(field.copy(dataType = map.copy(keyType = replace(keyStruct, names, update))))
Some(field.copy(dataType = map.copy(keyType = replace(keyStruct, names, update, ifExists))))

case (Seq("value"), map @ MapType(_, mapValueType, isNullable)) =>
val updated = update(StructField("value", mapValueType, nullable = isNullable))
@@ -251,7 +262,8 @@
valueContainsNull = updated.nullable)))

case (Seq("value", names @ _*), map @ MapType(_, valueStruct: StructType, _)) =>
Some(field.copy(dataType = map.copy(valueType = replace(valueStruct, names, update))))
Some(field.copy(dataType = map.copy(valueType =
replace(valueStruct, names, update, ifExists))))

case (Seq("element"), array @ ArrayType(elementType, isNullable)) =>
val updated = update(StructField("element", elementType, nullable = isNullable))
@@ -261,11 +273,15 @@
containsNull = updated.nullable)))

case (Seq("element", names @ _*), array @ ArrayType(elementStruct: StructType, _)) =>
Some(field.copy(dataType = array.copy(elementType = replace(elementStruct, names, update))))
Some(field.copy(dataType = array.copy(elementType =
replace(elementStruct, names, update, ifExists))))

case (names, dataType) =>
throw new IllegalArgumentException(
s"Cannot find field: ${names.head} in ${dataType.simpleString}")
if (!ifExists) {
throw new IllegalArgumentException(
s"Cannot find field: ${names.head} in ${dataType.simpleString}")
}
None
}

val newFields = struct.fields.zipWithIndex.flatMap {
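For intuition, a self-contained sketch of what the ifExists flag buys in the replace helper above, written against the public StructType API only — the schema and field names are invented, and unlike the real helper this sketch handles nested structs but not maps or arrays:

import org.apache.spark.sql.types._

def dropField(schema: StructType, path: Seq[String], ifExists: Boolean): StructType = {
  schema.getFieldIndex(path.head) match {
    case None if ifExists => schema  // column not found: return the schema untouched
    case None => throw new IllegalArgumentException(s"Cannot find field: ${path.head}")
    case Some(pos) =>
      val field = schema.fields(pos)
      val replacement: Option[StructField] = (path.tail, field.dataType) match {
        case (Seq(), _) => None  // leaf reached: drop this field
        case (rest, nested: StructType) => Some(field.copy(dataType = dropField(nested, rest, ifExists)))
        case (_, _) if ifExists => Some(field)  // nested path does not exist: keep the field as-is
        case (rest, other) =>
          throw new IllegalArgumentException(s"Cannot find field: ${rest.head} in ${other.simpleString}")
      }
      StructType(schema.fields.zipWithIndex.flatMap {
        case (_, i) if i == pos => replacement
        case (f, _) => Some(f)
      })
  }
}

val schema = new StructType()
  .add("id", IntegerType)
  .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))

dropField(schema, Seq("point", "y"), ifExists = false)  // keeps id and point.x
dropField(schema, Seq("missing"), ifExists = true)      // no error, schema returned unchanged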
@@ -1022,7 +1022,15 @@ class DDLParserSuite extends AnalysisTest {
parsePlan("ALTER TABLE table_name DROP COLUMN a.b.c"),
DropColumns(
UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
Seq(UnresolvedFieldName(Seq("a", "b", "c")))))
Seq(UnresolvedFieldName(Seq("a", "b", "c"))),
ifExists = false))

comparePlans(
parsePlan("ALTER TABLE table_name DROP COLUMN IF EXISTS a.b.c"),
DropColumns(
UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
Seq(UnresolvedFieldName(Seq("a", "b", "c"))),
ifExists = true))
}

test("alter table: drop multiple columns") {
@@ -1034,7 +1042,20 @@
UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
Seq(UnresolvedFieldName(Seq("x")),
UnresolvedFieldName(Seq("y")),
UnresolvedFieldName(Seq("a", "b", "c")))))
UnresolvedFieldName(Seq("a", "b", "c"))),
ifExists = false))
}

val sqlIfExists = "ALTER TABLE table_name DROP COLUMN IF EXISTS x, y, a.b.c"
Seq(sqlIfExists, sqlIfExists.replace("COLUMN", "COLUMNS")).foreach { drop =>
comparePlans(
parsePlan(drop),
DropColumns(
UnresolvedTable(Seq("table_name"), "ALTER TABLE ... DROP COLUMNS", None),
Seq(UnresolvedFieldName(Seq("x")),
UnresolvedFieldName(Seq("y")),
UnresolvedFieldName(Seq("a", "b", "c"))),
ifExists = true))
}
}

@@ -550,7 +550,7 @@ class CatalogSuite extends SparkFunSuite {
assert(table.schema == schema)

val updated = catalog.alterTable(testIdent,
TableChange.deleteColumn(Array("id")))
TableChange.deleteColumn(Array("id"), false))

val expectedSchema = new StructType().add("data", StringType)
assert(updated.schema == expectedSchema)
@@ -567,7 +567,7 @@
assert(table.schema == tableSchema)

val updated = catalog.alterTable(testIdent,
TableChange.deleteColumn(Array("point", "y")))
TableChange.deleteColumn(Array("point", "y"), false))

val newPointStruct = new StructType().add("x", DoubleType)
val expectedSchema = schema.add("point", newPointStruct)
@@ -583,11 +583,15 @@
assert(table.schema == schema)

val exc = intercept[IllegalArgumentException] {
catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col")))
catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), false))
}

assert(exc.getMessage.contains("missing_col"))
assert(exc.getMessage.contains("Cannot find"))

// with IF EXISTS it should pass
catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), true))
assert(table.schema == schema)
}

test("alterTable: delete missing nested column fails") {
@@ -601,11 +605,15 @@
assert(table.schema == tableSchema)

val exc = intercept[IllegalArgumentException] {
catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z")))
catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), false))
}

assert(exc.getMessage.contains("z"))
assert(exc.getMessage.contains("Cannot find"))

// with IF EXISTS it should pass
catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), true))
assert(table.schema == tableSchema)
}

test("alterTable: table does not exist") {
@@ -93,7 +93,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case RenameColumn(ResolvedV1TableIdentifier(_), _, _) =>
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError("RENAME COLUMN")

case DropColumns(ResolvedV1TableIdentifier(_), _) =>
case DropColumns(ResolvedV1TableIdentifier(_), _, _) =>
throw QueryCompilationErrors.operationOnlySupportedWithV2TableError("DROP COLUMN")

case SetTableProperties(ResolvedV1TableIdentifier(ident), props) =>
@@ -1070,7 +1070,7 @@ trait AlterTableTests extends SharedSparkSession {
}
}

test("AlterTable: drop column must exist") {
test("AlterTable: drop column must exist if required") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
sql(s"CREATE TABLE $t (id int) USING $v2Format")
@@ -1080,10 +1080,15 @@
}

assert(exc.getMessage.contains("Missing field data"))

// with IF EXISTS it should pass
sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS data")
val table = getTableMetadata(fullTableName(t))
assert(table.schema == new StructType().add("id", IntegerType))
}
}

test("AlterTable: nested drop column must exist") {
test("AlterTable: nested drop column must exist if required") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
sql(s"CREATE TABLE $t (id int) USING $v2Format")
@@ -1093,6 +1098,27 @@
}

assert(exc.getMessage.contains("Missing field point.x"))

// with IF EXISTS it should pass
sql(s"ALTER TABLE $t DROP COLUMN IF EXISTS point.x")
val table = getTableMetadata(fullTableName(t))
assert(table.schema == new StructType().add("id", IntegerType))

}
}

test("AlterTable: drop mixed existing/non-existing columns using IF EXISTS") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
sql(s"CREATE TABLE $t (id int, name string, points array<struct<x: double, y: double>>) " +
s"USING $v2Format")

// with IF EXISTS it should pass, silently skipping the missing columns
sql(s"ALTER TABLE $t DROP COLUMNS IF EXISTS " +
s"names, name, points.element.z, id, points.element.x")
val table = getTableMetadata(fullTableName(t))
assert(table.schema == new StructType()
.add("points", ArrayType(StructType(Seq(StructField("y", DoubleType))))))
}
}

@@ -286,10 +286,21 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes

test("AlterTable: drop column resolution") {
Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref =>
alterTableTest(
DropColumns(table, Seq(UnresolvedFieldName(ref))),
Seq("Missing field " + ref.quoted)
)
Seq(true, false).foreach { ifExists =>
val expectedErrors = if (ifExists) {
Seq.empty[String]
} else {
Seq("Missing field " + ref.quoted)
}
val alter = DropColumns(table, Seq(UnresolvedFieldName(ref)), ifExists)
if (ifExists) {
// using IF EXISTS will silence all errors for missing columns
assertAnalysisSuccess(alter, caseSensitive = true)
assertAnalysisSuccess(alter, caseSensitive = false)
} else {
alterTableTest(alter, expectedErrors, expectErrorOnCaseSensitive = true)
}
}
}
}
