diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ce82b3b567b54..aec71747fde68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3023,9 +3023,29 @@ class Analyzer( object ResolveAlterTableChanges extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved => + // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a + // normalized parent name of fields to field names that belong to the parent. + // For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become + // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")). + val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]] val schema = t.schema val normalizedChanges = changes.flatMap { case add: AddColumn => + def addColumn( + parentSchema: StructType, + parentName: String, + normalizedParentName: Seq[String]): TableChange = { + val fieldsAdded = colsToAdd.getOrElse(normalizedParentName, Nil) + val pos = findColumnPosition(add.position(), parentName, parentSchema, fieldsAdded) + val field = add.fieldNames().last + colsToAdd(normalizedParentName) = fieldsAdded :+ field + TableChange.addColumn( + (normalizedParentName :+ field).toArray, + add.dataType(), + add.isNullable, + add.comment, + pos) + } val parent = add.fieldNames().init if (parent.nonEmpty) { // Adding a nested field, need to normalize the parent column and position @@ -3037,27 +3057,14 @@ class Analyzer( val (normalizedName, sf) = target.get sf.dataType match { case struct: StructType => - val pos = findColumnPosition(add.position(), parent.quoted, struct) - Some(TableChange.addColumn( - (normalizedName ++ Seq(sf.name, 
add.fieldNames().last)).toArray, - add.dataType(), - add.isNullable, - add.comment, - pos)) - + Some(addColumn(struct, parent.quoted, normalizedName :+ sf.name)) case other => Some(add) } } } else { // Adding to the root. Just need to normalize position - val pos = findColumnPosition(add.position(), "root", schema) - Some(TableChange.addColumn( - add.fieldNames(), - add.dataType(), - add.isNullable, - add.comment, - pos)) + Some(addColumn(schema, "root", Nil)) } case typeChange: UpdateColumnType => @@ -3156,17 +3163,18 @@ class Analyzer( private def findColumnPosition( position: ColumnPosition, - field: String, - struct: StructType): ColumnPosition = { + parentName: String, + struct: StructType, + fieldsAdded: Seq[String]): ColumnPosition = { position match { case null => null case after: After => - struct.fieldNames.find(n => conf.resolver(n, after.column())) match { + (struct.fieldNames ++ fieldsAdded).find(n => conf.resolver(n, after.column())) match { case Some(colName) => ColumnPosition.after(colName) case None => throw new AnalysisException("Couldn't find the reference column for " + - s"$after at $field") + s"$after at $parentName") } case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 67c509ed98245..066dc6db0227d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -440,12 +440,16 @@ trait CheckAnalysis extends PredicateHelper { } field.get._2 } - def positionArgumentExists(position: ColumnPosition, struct: StructType): Unit = { + def positionArgumentExists( + position: ColumnPosition, + struct: StructType, + fieldsAdded: Seq[String]): Unit = { position match { case after: After => - if (!struct.fieldNames.contains(after.column())) { + val allFields = 
struct.fieldNames ++ fieldsAdded + if (!allFields.contains(after.column())) { alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " + - s"${struct.fieldNames.mkString("[", ", ", "]")}") + s"${allFields.mkString("[", ", ", "]")}") } case _ => } @@ -474,6 +478,11 @@ trait CheckAnalysis extends PredicateHelper { } val colsToDelete = mutable.Set.empty[Seq[String]] + // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a parent + // name of fields to field names that belong to the parent. For example, if we add + // columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become + // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")). + val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]] alter.changes.foreach { case add: AddColumn => @@ -483,8 +492,11 @@ trait CheckAnalysis extends PredicateHelper { checkColumnNotExists("add", add.fieldNames(), table.schema) } val parent = findParentStruct("add", add.fieldNames()) - positionArgumentExists(add.position(), parent) + val parentName = add.fieldNames().init + val fieldsAdded = colsToAdd.getOrElse(parentName, Nil) + positionArgumentExists(add.position(), parent, fieldsAdded) TypeUtils.failWithIntervalType(add.dataType()) + colsToAdd(parentName) = fieldsAdded :+ add.fieldNames().last case update: UpdateColumnType => val field = findField("update", update.fieldNames) val fieldName = update.fieldNames.quoted @@ -523,7 +535,11 @@ trait CheckAnalysis extends PredicateHelper { case updatePos: UpdateColumnPosition => findField("update", updatePos.fieldNames) val parent = findParentStruct("update", updatePos.fieldNames()) - positionArgumentExists(updatePos.position(), parent) + val parentName = updatePos.fieldNames().init + positionArgumentExists( + updatePos.position(), + parent, + colsToAdd.getOrElse(parentName, Nil)) case rename: RenameColumn => findField("rename", rename.fieldNames) checkColumnNotExists( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 96fe301b512ea..d04a1fca6387c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -173,6 +173,42 @@ trait AlterTableTests extends SharedSparkSession { } } + test("SPARK-30814: add column with position referencing new columns being added") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql(s"CREATE TABLE $t (a string, b int, point struct<x: double, y: double>) USING $v2Format") + sql(s"ALTER TABLE $t ADD COLUMNS (x int AFTER a, y int AFTER x, z int AFTER y)") + + assert(getTableMetadata(t).schema === new StructType() + .add("a", StringType) + .add("x", IntegerType) + .add("y", IntegerType) + .add("z", IntegerType) + .add("b", IntegerType) + .add("point", new StructType() + .add("x", DoubleType) + .add("y", DoubleType))) + + sql(s"ALTER TABLE $t ADD COLUMNS (point.z double AFTER x, point.zz double AFTER z)") + assert(getTableMetadata(t).schema === new StructType() + .add("a", StringType) + .add("x", IntegerType) + .add("y", IntegerType) + .add("z", IntegerType) + .add("b", IntegerType) + .add("point", new StructType() + .add("x", DoubleType) + .add("z", DoubleType) + .add("zz", DoubleType) + .add("y", DoubleType))) + + // The new column being referenced should come before being referenced. 
+ val e = intercept[AnalysisException]( + sql(s"ALTER TABLE $t ADD COLUMNS (yy int AFTER xx, xx int)")) + assert(e.getMessage().contains("Couldn't find the reference column for AFTER xx at root")) + } + } + test("AlterTable: add multiple columns") { val t = s"${catalogAndNamespace}table_name" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index 289f9dc427795..dd95ceb59bdc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -151,6 +151,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes } } + test("AlterTable: add column resolution - column position referencing new column") { + alterTableTest( + Seq( + TableChange.addColumn( + Array("x"), LongType, true, null, ColumnPosition.after("id")), + TableChange.addColumn( + Array("y"), LongType, true, null, ColumnPosition.after("X"))), + Seq("Couldn't find the reference column for AFTER X at root") + ) + } + test("AlterTable: add column resolution - nested positional") { Seq("X", "Y").foreach { ref => alterTableTest( @@ -161,6 +172,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes } } + test("AlterTable: add column resolution - column position referencing new nested column") { + alterTableTest( + Seq( + TableChange.addColumn( + Array("point", "z"), LongType, true, null), + TableChange.addColumn( + Array("point", "zz"), LongType, true, null, ColumnPosition.after("Z"))), + Seq("Couldn't find the reference column for AFTER Z at point") + ) + } + test("AlterTable: drop column resolution") { Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref => alterTableTest( @@ -207,13 +229,17 @@ class 
V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes } private def alterTableTest(change: TableChange, error: Seq[String]): Unit = { + alterTableTest(Seq(change), error) + } + + private def alterTableTest(changes: Seq[TableChange], error: Seq[String]): Unit = { Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val plan = AlterTable( catalog, Identifier.of(Array(), "table_name"), TestRelation2, - Seq(change) + changes ) if (caseSensitive) {