Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3023,9 +3023,29 @@ class Analyzer(
object ResolveAlterTableChanges extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved =>
// 'colsToAdd' keeps track of new columns being added. It stores a mapping from a
// normalized parent name of fields to field names that belong to the parent.
// For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
// Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
val schema = t.schema
val normalizedChanges = changes.flatMap {
case add: AddColumn =>
// Normalizes a single AddColumn change whose parent struct has already been resolved.
// 'parentSchema' is the struct the new field goes into, 'parentName' is its display
// name for error messages ("root" for the top level), and 'normalizedParentName' is
// the case-normalized field path to that struct (empty for the root schema).
def addColumn(
parentSchema: StructType,
parentName: String,
normalizedParentName: Seq[String]): TableChange = {
// Columns added to this same parent earlier in this ALTER TABLE statement; an AFTER
// position may reference them even though they are not yet part of 'parentSchema'.
val fieldsAdded = colsToAdd.getOrElse(normalizedParentName, Nil)
val pos = findColumnPosition(add.position(), parentName, parentSchema, fieldsAdded)
val field = add.fieldNames().last
// Record this field so later ADD COLUMN changes in the statement can say AFTER it.
colsToAdd(normalizedParentName) = fieldsAdded :+ field
TableChange.addColumn(
(normalizedParentName :+ field).toArray,
add.dataType(),
add.isNullable,
add.comment,
pos)
}
val parent = add.fieldNames().init
if (parent.nonEmpty) {
// Adding a nested field, need to normalize the parent column and position
Expand All @@ -3037,27 +3057,14 @@ class Analyzer(
val (normalizedName, sf) = target.get
sf.dataType match {
case struct: StructType =>
val pos = findColumnPosition(add.position(), parent.quoted, struct)
Some(TableChange.addColumn(
(normalizedName ++ Seq(sf.name, add.fieldNames().last)).toArray,
add.dataType(),
add.isNullable,
add.comment,
pos))

Some(addColumn(struct, parent.quoted, normalizedName :+ sf.name))
case other =>
Some(add)
}
}
} else {
// Adding to the root. Just need to normalize position
val pos = findColumnPosition(add.position(), "root", schema)
Some(TableChange.addColumn(
add.fieldNames(),
add.dataType(),
add.isNullable,
add.comment,
pos))
Some(addColumn(schema, "root", Nil))
}

case typeChange: UpdateColumnType =>
Expand Down Expand Up @@ -3156,17 +3163,18 @@ class Analyzer(

/**
 * Normalizes a user-supplied column position against the given struct.
 *
 * @param position the requested position; may be null (no position), FIRST, or AFTER
 *                 some reference column
 * @param parentName display name of the parent struct used in error messages
 *                   ("root" for the top level, otherwise the quoted nested path)
 * @param struct the struct the column is being placed into
 * @param fieldsAdded names of columns already added to this struct earlier in the same
 *                    ALTER TABLE statement; an AFTER reference may name one of these
 *                    even though it is not yet part of 'struct'
 * @return the position, with an AFTER reference column name case-normalized to the
 *         matching existing (or newly added) field name
 */
private def findColumnPosition(
    position: ColumnPosition,
    parentName: String,
    struct: StructType,
    fieldsAdded: Seq[String]): ColumnPosition = {
  position match {
    case null => null
    case after: After =>
      // Resolve the AFTER reference among existing fields plus columns added earlier
      // in this statement, honoring the session's case-sensitivity resolver.
      (struct.fieldNames ++ fieldsAdded).find(n => conf.resolver(n, after.column())) match {
        case Some(colName) =>
          ColumnPosition.after(colName)
        case None =>
          throw new AnalysisException("Couldn't find the reference column for " +
            s"$after at $parentName")
      }
    case other => other
  }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,16 @@ trait CheckAnalysis extends PredicateHelper {
}
field.get._2
}
// Verifies that an AFTER position refers to an existing field of 'struct' or to a
// column already added to the same struct earlier in this ALTER TABLE statement
// ('fieldsAdded'); fails analysis otherwise. FIRST/null positions need no check.
def positionArgumentExists(
    position: ColumnPosition,
    struct: StructType,
    fieldsAdded: Seq[String]): Unit = {
  position match {
    case after: After =>
      // Newly added sibling columns are valid AFTER targets too.
      val allFields = struct.fieldNames ++ fieldsAdded
      if (!allFields.contains(after.column())) {
        alter.failAnalysis(s"Couldn't resolve positional argument $position amongst " +
          s"${allFields.mkString("[", ", ", "]")}")
      }
    case _ =>
  }
Expand Down Expand Up @@ -474,6 +478,11 @@ trait CheckAnalysis extends PredicateHelper {
}

val colsToDelete = mutable.Set.empty[Seq[String]]
// 'colsToAdd' keeps track of new columns being added. It stores a mapping from a parent
// name of fields to field names that belong to the parent. For example, if we add
// columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
// Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]

alter.changes.foreach {
case add: AddColumn =>
Expand All @@ -483,8 +492,11 @@ trait CheckAnalysis extends PredicateHelper {
checkColumnNotExists("add", add.fieldNames(), table.schema)
}
val parent = findParentStruct("add", add.fieldNames())
positionArgumentExists(add.position(), parent)
val parentName = add.fieldNames().init
val fieldsAdded = colsToAdd.getOrElse(parentName, Nil)
positionArgumentExists(add.position(), parent, fieldsAdded)
TypeUtils.failWithIntervalType(add.dataType())
colsToAdd(parentName) = fieldsAdded :+ add.fieldNames().last
case update: UpdateColumnType =>
val field = findField("update", update.fieldNames)
val fieldName = update.fieldNames.quoted
Expand Down Expand Up @@ -523,7 +535,11 @@ trait CheckAnalysis extends PredicateHelper {
case updatePos: UpdateColumnPosition =>
findField("update", updatePos.fieldNames)
val parent = findParentStruct("update", updatePos.fieldNames())
positionArgumentExists(updatePos.position(), parent)
val parentName = updatePos.fieldNames().init
positionArgumentExists(
updatePos.position(),
parent,
colsToAdd.getOrElse(parentName, Nil))
case rename: RenameColumn =>
findField("rename", rename.fieldNames)
checkColumnNotExists(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,42 @@ trait AlterTableTests extends SharedSparkSession {
}
}

// An AFTER position may reference a column added earlier in the same ALTER TABLE
// statement, both at the root and inside a nested struct.
test("SPARK-30814: add column with position referencing new columns being added") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
sql(s"CREATE TABLE $t (a string, b int, point struct<x: double, y: double>) USING $v2Format")
sql(s"ALTER TABLE $t ADD COLUMNS (x int AFTER a, y int AFTER x, z int AFTER y)")

assert(getTableMetadata(t).schema === new StructType()
.add("a", StringType)
.add("x", IntegerType)
.add("y", IntegerType)
.add("z", IntegerType)
.add("b", IntegerType)
.add("point", new StructType()
.add("x", DoubleType)
.add("y", DoubleType)))

sql(s"ALTER TABLE $t ADD COLUMNS (point.z double AFTER x, point.zz double AFTER z)")
assert(getTableMetadata(t).schema === new StructType()
.add("a", StringType)
.add("x", IntegerType)
.add("y", IntegerType)
.add("z", IntegerType)
.add("b", IntegerType)
.add("point", new StructType()
.add("x", DoubleType)
.add("z", DoubleType)
.add("zz", DoubleType)
.add("y", DoubleType)))

// A column referenced by AFTER must be added earlier (to its left) in the same
// statement; "yy AFTER xx" fails because "xx" is added later in the list.
val e = intercept[AnalysisException](
sql(s"ALTER TABLE $t ADD COLUMNS (yy int AFTER xx, xx int)"))
assert(e.getMessage().contains("Couldn't find the reference column for AFTER xx at root"))
}
}

test("AlterTable: add multiple columns") {
val t = s"${catalogAndNamespace}table_name"
withTable(t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
}
}

// AFTER may reference a column added earlier in the same statement; here the
// reference "X" differs in case from the added column "x", so resolution is
// expected to fail with the given message (presumably under the case-sensitive
// leg of alterTableTest — behavior of the helper is defined elsewhere).
test("AlterTable: add column resolution - column position referencing new column") {
alterTableTest(
Seq(
TableChange.addColumn(
Array("x"), LongType, true, null, ColumnPosition.after("id")),
TableChange.addColumn(
Array("y"), LongType, true, null, ColumnPosition.after("X"))),
Seq("Couldn't find the reference column for AFTER X at root")
)
}

test("AlterTable: add column resolution - nested positional") {
Seq("X", "Y").foreach { ref =>
alterTableTest(
Expand All @@ -161,6 +172,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
}
}

// Same as the root-level case but inside a nested struct: "Z" differs in case from
// the newly added nested column "point.z", so the AFTER reference is expected to
// fail resolution with the given message (error path exercised by alterTableTest).
test("AlterTable: add column resolution - column position referencing new nested column") {
alterTableTest(
Seq(
TableChange.addColumn(
Array("point", "z"), LongType, true, null),
TableChange.addColumn(
Array("point", "zz"), LongType, true, null, ColumnPosition.after("Z"))),
Seq("Couldn't find the reference column for AFTER Z at point")
)
}

test("AlterTable: drop column resolution") {
Seq(Array("ID"), Array("point", "X"), Array("POINT", "X"), Array("POINT", "x")).foreach { ref =>
alterTableTest(
Expand Down Expand Up @@ -207,13 +229,17 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes
}

/** Convenience overload: runs the multi-change variant with a single table change. */
private def alterTableTest(change: TableChange, error: Seq[String]): Unit =
  alterTableTest(change :: Nil, error)

private def alterTableTest(changes: Seq[TableChange], error: Seq[String]): Unit = {
Seq(true, false).foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
val plan = AlterTable(
catalog,
Identifier.of(Array(), "table_name"),
TestRelation2,
Seq(change)
changes
)

if (caseSensitive) {
Expand Down