Skip to content

Commit 74afc68

Browse files
Kimahriman authored and viirya committed
[SPARK-35213][SQL] Keep the correct ordering of nested structs in chained withField operations
### What changes were proposed in this pull request?

Modifies the UpdateFields optimizer to fix correctness issues with certain nested and chained withField operations. Examples for recreating the issue are in the new unit tests as well as the JIRA issue.

### Why are the changes needed?

Certain withField patterns can cause Exceptions or even incorrect results. It appears to be a result of the additional UpdateFields optimization added in #29812. It traverses fieldOps in reverse order to take the last one per field, but this can cause nested structs to change order, which leads to mismatches between the schema and the actual data. This updates the optimization to maintain the initial ordering of nested structs to match the generated schema.

### Does this PR introduce _any_ user-facing change?

It fixes exceptions and incorrect results for valid uses in the latest Spark release.

### How was this patch tested?

Added new unit tests for these edge cases.

Closes #32338 from Kimahriman/bug/optimize-with-fields.

Authored-by: Adam Binford <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent d572a85 commit 74afc68

File tree

3 files changed

+88
-18
lines changed

3 files changed

+88
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,28 +49,22 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] {
4949
val values = withFields.map(_.valExpr)
5050

5151
val newNames = mutable.ArrayBuffer.empty[String]
52-
val newValues = mutable.ArrayBuffer.empty[Expression]
52+
val newValues = mutable.HashMap.empty[String, Expression]
53+
// Used to remember the casing of the last instance
54+
val nameMap = mutable.HashMap.empty[String, String]
5355

54-
if (caseSensitive) {
55-
names.zip(values).reverse.foreach { case (name, value) =>
56-
if (!newNames.contains(name)) {
57-
newNames += name
58-
newValues += value
59-
}
60-
}
61-
} else {
62-
val nameSet = mutable.HashSet.empty[String]
63-
names.zip(values).reverse.foreach { case (name, value) =>
64-
val lowercaseName = name.toLowerCase(Locale.ROOT)
65-
if (!nameSet.contains(lowercaseName)) {
66-
newNames += name
67-
newValues += value
68-
nameSet += lowercaseName
69-
}
56+
names.zip(values).foreach { case (name, value) =>
57+
val normalizedName = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
58+
if (nameMap.contains(normalizedName)) {
59+
newValues += normalizedName -> value
60+
} else {
61+
newNames += normalizedName
62+
newValues += normalizedName -> value
7063
}
64+
nameMap += normalizedName -> name
7165
}
7266

73-
val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
67+
val newWithFields = newNames.map(n => WithField(nameMap(n), newValues(n)))
7468
UpdateFields(structExpr, newWithFields.toSeq)
7569

7670
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,25 @@ class OptimizeWithFieldsSuite extends PlanTest {
126126
comparePlans(optimized, correctAnswer)
127127
}
128128
}
129+
130+
test("SPARK-35213: ensure optimize WithFields maintains correct WithField ordering") {
131+
val originalQuery = testRelation
132+
.select(
133+
Alias(UpdateFields('a,
134+
WithField("a1", Literal(3)) ::
135+
WithField("b1", Literal(4)) ::
136+
WithField("a1", Literal(5)) ::
137+
Nil), "out")())
138+
139+
val optimized = Optimize.execute(originalQuery.analyze)
140+
val correctAnswer = testRelation
141+
.select(
142+
Alias(UpdateFields('a,
143+
WithField("a1", Literal(5)) ::
144+
WithField("b1", Literal(4)) ::
145+
Nil), "out")())
146+
.analyze
147+
148+
comparePlans(optimized, correctAnswer)
149+
}
129150
}

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,61 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
16861686
StructType(Seq(StructField("a", IntegerType, nullable = true))))
16871687
}
16881688

1689+
test("SPARK-35213: chained withField operations should have correct schema for new columns") {
1690+
val df = spark.createDataFrame(
1691+
sparkContext.parallelize(Row(null) :: Nil),
1692+
StructType(Seq(StructField("data", NullType))))
1693+
1694+
checkAnswer(
1695+
df.withColumn("data", struct()
1696+
.withField("a", struct())
1697+
.withField("b", struct())
1698+
.withField("a.aa", lit("aa1"))
1699+
.withField("b.ba", lit("ba1"))
1700+
.withField("a.ab", lit("ab1"))),
1701+
Row(Row(Row("aa1", "ab1"), Row("ba1"))) :: Nil,
1702+
StructType(Seq(
1703+
StructField("data", StructType(Seq(
1704+
StructField("a", StructType(Seq(
1705+
StructField("aa", StringType, nullable = false),
1706+
StructField("ab", StringType, nullable = false)
1707+
)), nullable = false),
1708+
StructField("b", StructType(Seq(
1709+
StructField("ba", StringType, nullable = false)
1710+
)), nullable = false)
1711+
)), nullable = false)
1712+
))
1713+
)
1714+
}
1715+
1716+
test("SPARK-35213: optimized withField operations should maintain correct nested struct " +
1717+
"ordering") {
1718+
val df = spark.createDataFrame(
1719+
sparkContext.parallelize(Row(null) :: Nil),
1720+
StructType(Seq(StructField("data", NullType))))
1721+
1722+
checkAnswer(
1723+
df.withColumn("data", struct()
1724+
.withField("a", struct().withField("aa", lit("aa1")))
1725+
.withField("b", struct().withField("ba", lit("ba1")))
1726+
)
1727+
.withColumn("data", col("data").withField("b.bb", lit("bb1")))
1728+
.withColumn("data", col("data").withField("a.ab", lit("ab1"))),
1729+
Row(Row(Row("aa1", "ab1"), Row("ba1", "bb1"))) :: Nil,
1730+
StructType(Seq(
1731+
StructField("data", StructType(Seq(
1732+
StructField("a", StructType(Seq(
1733+
StructField("aa", StringType, nullable = false),
1734+
StructField("ab", StringType, nullable = false)
1735+
)), nullable = false),
1736+
StructField("b", StructType(Seq(
1737+
StructField("ba", StringType, nullable = false),
1738+
StructField("bb", StringType, nullable = false)
1739+
)), nullable = false)
1740+
)), nullable = false)
1741+
))
1742+
)
1743+
}
16891744

16901745
test("dropFields should throw an exception if called on a non-StructType column") {
16911746
intercept[AnalysisException] {

0 commit comments

Comments (0)