diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
index 8964a2776b097..be39c3f10e612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
@@ -49,28 +49,22 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] {
       val values = withFields.map(_.valExpr)
 
       val newNames = mutable.ArrayBuffer.empty[String]
-      val newValues = mutable.ArrayBuffer.empty[Expression]
+      val newValues = mutable.HashMap.empty[String, Expression]
+      // Used to remember the casing of the last instance
+      val nameMap = mutable.HashMap.empty[String, String]
 
-      if (caseSensitive) {
-        names.zip(values).reverse.foreach { case (name, value) =>
-          if (!newNames.contains(name)) {
-            newNames += name
-            newValues += value
-          }
-        }
-      } else {
-        val nameSet = mutable.HashSet.empty[String]
-        names.zip(values).reverse.foreach { case (name, value) =>
-          val lowercaseName = name.toLowerCase(Locale.ROOT)
-          if (!nameSet.contains(lowercaseName)) {
-            newNames += name
-            newValues += value
-            nameSet += lowercaseName
-          }
+      names.zip(values).foreach { case (name, value) =>
+        val normalizedName = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+        if (nameMap.contains(normalizedName)) {
+          newValues += normalizedName -> value
+        } else {
+          newNames += normalizedName
+          newValues += normalizedName -> value
         }
+        nameMap += normalizedName -> name
       }
 
-      val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
+      val newWithFields = newNames.map(n => WithField(nameMap(n), newValues(n)))
       UpdateFields(structExpr, newWithFields.toSeq)
 
     case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
index b093b39cc4b88..e63742ac0de56 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
@@ -126,4 +126,25 @@ class OptimizeWithFieldsSuite extends PlanTest {
       comparePlans(optimized, correctAnswer)
     }
   }
+
+  test("SPARK-35213: ensure optimize WithFields maintains correct WithField ordering") {
+    val originalQuery = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("a1", Literal(3)) ::
+          WithField("b1", Literal(4)) ::
+          WithField("a1", Literal(5)) ::
+          Nil), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("a1", Literal(5)) ::
+          WithField("b1", Literal(4)) ::
+          Nil), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 5108c0169b68b..ad5d73c774274 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -1686,6 +1686,61 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
   }
 
+  test("SPARK-35213: chained withField operations should have correct schema for new columns") {
+    val df = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("data", NullType))))
+
+    checkAnswer(
+      df.withColumn("data", struct()
+        .withField("a", struct())
+        .withField("b", struct())
+        .withField("a.aa", lit("aa1"))
+        .withField("b.ba", lit("ba1"))
+        .withField("a.ab", lit("ab1"))),
+      Row(Row(Row("aa1", "ab1"), Row("ba1"))) :: Nil,
+      StructType(Seq(
+        StructField("data", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("aa", StringType, nullable = false),
+            StructField("ab", StringType, nullable = false)
+          )), nullable = false),
+          StructField("b", StructType(Seq(
+            StructField("ba", StringType, nullable = false)
+          )), nullable = false)
+        )), nullable = false)
+      ))
+    )
+  }
+
+  test("SPARK-35213: optimized withField operations should maintain correct nested struct " +
+    "ordering") {
+    val df = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("data", NullType))))
+
+    checkAnswer(
+      df.withColumn("data", struct()
+        .withField("a", struct().withField("aa", lit("aa1")))
+        .withField("b", struct().withField("ba", lit("ba1")))
+      )
+        .withColumn("data", col("data").withField("b.bb", lit("bb1")))
+        .withColumn("data", col("data").withField("a.ab", lit("ab1"))),
+      Row(Row(Row("aa1", "ab1"), Row("ba1", "bb1"))) :: Nil,
+      StructType(Seq(
+        StructField("data", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("aa", StringType, nullable = false),
+            StructField("ab", StringType, nullable = false)
+          )), nullable = false),
+          StructField("b", StructType(Seq(
+            StructField("ba", StringType, nullable = false),
+            StructField("bb", StringType, nullable = false)
+          )), nullable = false)
+        )), nullable = false)
+      ))
+    )
+  }
   test("dropFields should throw an exception if called on a non-StructType column") {
     intercept[AnalysisException] {
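
For context, a minimal standalone Scala sketch of the deduplication strategy the patched rule implements (not part of the patch; `DedupSketch`, `dedupFields`, and the `Int` stand-in values are illustrative only): a field keeps the position of its first occurrence, while the value and the name casing of its last occurrence win. This is exactly the ordering the new OptimizeWithFieldsSuite test pins down.

import java.util.Locale
import scala.collection.mutable

// Illustrative sketch, not part of the patch: the same deduplication algorithm as
// the patched OptimizeUpdateFields rule, with Int values standing in for Expressions.
object DedupSketch extends App {
  def dedupFields(fields: Seq[(String, Int)], caseSensitive: Boolean): Seq[(String, Int)] = {
    val newNames = mutable.ArrayBuffer.empty[String]    // field positions: first occurrence
    val newValues = mutable.HashMap.empty[String, Int]  // field values: last occurrence
    val nameMap = mutable.HashMap.empty[String, String] // name casing: last occurrence
    fields.foreach { case (name, value) =>
      val normalized = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
      if (!nameMap.contains(normalized)) newNames += normalized
      newValues += normalized -> value
      nameMap += normalized -> name
    }
    newNames.map(n => (nameMap(n), newValues(n))).toSeq
  }

  // Mirrors the new OptimizeWithFieldsSuite case: a1=3, b1=4, a1=5 collapses to
  // a1=5, b1=4, i.e. "a1" keeps its first position but takes its last value.
  assert(dedupFields(Seq("a1" -> 3, "b1" -> 4, "a1" -> 5), caseSensitive = true) ==
    Seq("a1" -> 5, "b1" -> 4))
}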