Skip to content

Commit c01f565

Browse files
committed
address comments
1 parent ea58952 commit c01f565

File tree

3 files changed

+92
-30
lines changed

3 files changed

+92
-30
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ public int hashCode() {
301301
interface ColumnPosition {
302302

303303
static ColumnPosition first() {
304-
return First.singleton;
304+
return First.SINGLETON;
305305
}
306306

307307
static ColumnPosition after(String column) {
@@ -315,7 +315,7 @@ static ColumnPosition after(String column) {
315315
* be the first one within the struct.
316316
*/
317317
final class First implements ColumnPosition {
318-
private static First singleton = new First();
318+
private static final First SINGLETON = new First();
319319

320320
private First() {}
321321

sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ private[sql] object CatalogV2Util {
140140
case update: UpdateColumnPosition =>
141141
def updateFieldPos(struct: StructType, name: String): StructType = {
142142
val oldField = struct.fields.find(_.name == name).getOrElse {
143-
throw new IllegalArgumentException("field not found: " + name)
143+
throw new IllegalArgumentException("Field not found: " + name)
144144
}
145145
val withFieldRemoved = StructType(struct.fields.filter(_ != oldField))
146146
addField(withFieldRemoved, oldField, update.position())
@@ -153,6 +153,8 @@ private[sql] object CatalogV2Util {
153153
replace(schema, names.init, parent => parent.dataType match {
154154
case parentType: StructType =>
155155
Some(parent.copy(dataType = updateFieldPos(parentType, names.last)))
156+
case _ =>
157+
throw new IllegalArgumentException(s"Not a struct: ${names.init.last}")
156158
})
157159
}
158160

@@ -176,7 +178,11 @@ private[sql] object CatalogV2Util {
176178
StructType(field +: schema.fields)
177179
} else {
178180
val afterCol = position.asInstanceOf[After].column()
179-
val (before, after) = schema.fields.span(_.name == afterCol)
181+
val fieldIndex = schema.fields.indexWhere(_.name == afterCol)
182+
if (fieldIndex == -1) {
183+
throw new IllegalArgumentException("AFTER column not found: " + afterCol)
184+
}
185+
val (before, after) = schema.fields.splitAt(fieldIndex + 1)
180186
StructType(before ++ (field +: after))
181187
}
182188
}

sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -104,21 +104,43 @@ trait AlterTableTests extends SharedSparkSession {
104104
test("AlterTable: add column with position") {
105105
val t = s"${catalogAndNamespace}table_name"
106106
withTable(t) {
107-
sql(s"CREATE TABLE $t (id struct<x: int>) USING $v2Format")
107+
sql(s"CREATE TABLE $t (point struct<x: int>) USING $v2Format")
108108

109109
sql(s"ALTER TABLE $t ADD COLUMN a string FIRST")
110-
assert(getTableMetadata(t).schema.names.toSeq == Seq("a", "id"))
111-
112-
sql(s"ALTER TABLE $t ADD COLUMN b string AFTER a")
113-
assert(getTableMetadata(t).schema.names.toSeq == Seq("a", "b", "id"))
114-
115-
sql(s"ALTER TABLE $t ADD COLUMN id.y string FIRST")
116-
assert(getTableMetadata(t).schema.last.dataType.asInstanceOf[StructType].names.toSeq ==
117-
Seq("y", "x"))
118-
119-
sql(s"ALTER TABLE $t ADD COLUMN id.z string AFTER y")
120-
assert(getTableMetadata(t).schema.last.dataType.asInstanceOf[StructType].names.toSeq ==
121-
Seq("y", "z", "x"))
110+
assert(getTableMetadata(t).schema == new StructType()
111+
.add("a", StringType)
112+
.add("point", new StructType().add("x", IntegerType)))
113+
114+
sql(s"ALTER TABLE $t ADD COLUMN b string AFTER point")
115+
assert(getTableMetadata(t).schema == new StructType()
116+
.add("a", StringType)
117+
.add("point", new StructType().add("x", IntegerType))
118+
.add("b", StringType))
119+
120+
val e1 = intercept[SparkException](
121+
sql(s"ALTER TABLE $t ADD COLUMN c string AFTER non_exist"))
122+
assert(e1.getMessage().contains("AFTER column not found"))
123+
124+
sql(s"ALTER TABLE $t ADD COLUMN point.y int FIRST")
125+
assert(getTableMetadata(t).schema == new StructType()
126+
.add("a", StringType)
127+
.add("point", new StructType()
128+
.add("y", IntegerType)
129+
.add("x", IntegerType))
130+
.add("b", StringType))
131+
132+
sql(s"ALTER TABLE $t ADD COLUMN point.z int AFTER x")
133+
assert(getTableMetadata(t).schema == new StructType()
134+
.add("a", StringType)
135+
.add("point", new StructType()
136+
.add("y", IntegerType)
137+
.add("x", IntegerType)
138+
.add("z", IntegerType))
139+
.add("b", StringType))
140+
141+
val e2 = intercept[SparkException](
142+
sql(s"ALTER TABLE $t ADD COLUMN point.x2 int AFTER non_exist"))
143+
assert(e2.getMessage().contains("AFTER column not found"))
122144
}
123145
}
124146

@@ -495,21 +517,55 @@ trait AlterTableTests extends SharedSparkSession {
495517
test("AlterTable: update column position") {
496518
val t = s"${catalogAndNamespace}table_name"
497519
withTable(t) {
498-
sql(s"CREATE TABLE $t (a int, b struct<x: int, y: int>) USING $v2Format")
520+
sql(s"CREATE TABLE $t (a int, b int, point struct<x: int, y: int, z: int>) USING $v2Format")
499521

500522
sql(s"ALTER TABLE $t ALTER COLUMN b FIRST")
501-
assert(getTableMetadata(t).schema().names.toSeq == Seq("b", "a"))
502-
503-
sql(s"ALTER TABLE $t ALTER COLUMN b AFTER a")
504-
assert(getTableMetadata(t).schema().names.toSeq == Seq("a", "b"))
505-
506-
sql(s"ALTER TABLE $t ALTER COLUMN b.y FIRST")
507-
assert(getTableMetadata(t).schema.apply("b").dataType.asInstanceOf[StructType].names.toSeq ==
508-
Seq("y", "x"))
509-
510-
sql(s"ALTER TABLE $t ALTER COLUMN b.y AFTER x")
511-
assert(getTableMetadata(t).schema.apply("b").dataType.asInstanceOf[StructType].names.toSeq ==
512-
Seq("x", "y"))
523+
assert(getTableMetadata(t).schema == new StructType()
524+
.add("b", IntegerType)
525+
.add("a", IntegerType)
526+
.add("point", new StructType()
527+
.add("x", IntegerType)
528+
.add("y", IntegerType)
529+
.add("z", IntegerType)))
530+
531+
sql(s"ALTER TABLE $t ALTER COLUMN b AFTER point")
532+
assert(getTableMetadata(t).schema == new StructType()
533+
.add("a", IntegerType)
534+
.add("point", new StructType()
535+
.add("x", IntegerType)
536+
.add("y", IntegerType)
537+
.add("z", IntegerType))
538+
.add("b", IntegerType))
539+
540+
val e1 = intercept[SparkException](
541+
sql(s"ALTER TABLE $t ALTER COLUMN b AFTER non_exist"))
542+
assert(e1.getMessage.contains("AFTER column not found"))
543+
544+
sql(s"ALTER TABLE $t ALTER COLUMN point.y FIRST")
545+
assert(getTableMetadata(t).schema == new StructType()
546+
.add("a", IntegerType)
547+
.add("point", new StructType()
548+
.add("y", IntegerType)
549+
.add("x", IntegerType)
550+
.add("z", IntegerType))
551+
.add("b", IntegerType))
552+
553+
sql(s"ALTER TABLE $t ALTER COLUMN point.y AFTER z")
554+
assert(getTableMetadata(t).schema == new StructType()
555+
.add("a", IntegerType)
556+
.add("point", new StructType()
557+
.add("x", IntegerType)
558+
.add("z", IntegerType)
559+
.add("y", IntegerType))
560+
.add("b", IntegerType))
561+
562+
val e2 = intercept[SparkException](
563+
sql(s"ALTER TABLE $t ALTER COLUMN point.y AFTER non_exist"))
564+
assert(e2.getMessage.contains("AFTER column not found"))
565+
566+
// `AlterTable.resolved` checks column existence.
567+
intercept[AnalysisException](
568+
sql(s"ALTER TABLE $t ALTER COLUMN a.y AFTER x"))
513569
}
514570
}
515571

0 commit comments

Comments
 (0)