diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index dd654cde1996..17701ca1dbc8 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -96,6 +96,8 @@ license: | - In Spark 3.2, `FloatType` is mapped to `FLOAT` in MySQL. Prior to this, it used to be mapped to `REAL`, which is by default a synonym to `DOUBLE PRECISION` in MySQL. - In Spark 3.2, the query executions triggered by `DataFrameWriter` are always named `command` when being sent to `QueryExecutionListener`. In Spark 3.1 and earlier, the name is one of `save`, `insertInto`, `saveAsTable`. + + - In Spark 3.2, `Dataset.unionByName` with `allowMissingColumns` set to true will add missing nested fields to the end of structs. In Spark 3.1, nested struct fields are sorted alphabetically. ## Upgrading from Spark SQL 3.0 to 3.1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 08cc61f81900..2574bf7ab485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -20,136 +20,63 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields} +import org.apache.spark.sql.catalyst.optimizer.{CombineUnions} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.UNION import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils -import org.apache.spark.unsafe.types.UTF8String /** * Resolves different children of Union to a common set of columns. 
*/ object ResolveUnion extends Rule[LogicalPlan] { /** - * This method sorts columns recursively in a struct expression based on column names. + * Adds missing fields recursively into given `col` expression, based on the expected struct + * fields from merging the two schemas. This is called by `compareAndAddFields` when we find two + * struct columns with same name but different nested fields. This method will recursively + * return a new struct with all of the expected fields, adding null values when `col` doesn't + * already contain them. Currently we don't support merging structs nested inside of arrays + * or maps. */ - private def sortStructFields(expr: Expression): Expression = { - val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { - case (name, i) => - val fieldExpr = GetStructField(KnownNotNull(expr), i) - if (fieldExpr.dataType.isInstanceOf[StructType]) { - (name, sortStructFields(fieldExpr)) - } else { - (name, fieldExpr) - } - }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2)) + private def addFields(col: Expression, targetType: StructType): Expression = { + assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") - val newExpr = CreateNamedStruct(existingExprs) - if (expr.nullable) { - If(IsNull(expr), Literal(null, newExpr.dataType), newExpr) - } else { - newExpr - } - } + val resolver = conf.resolver + val colType = col.dataType.asInstanceOf[StructType] - /** - * Assumes input expressions are field expression of `CreateNamedStruct`. This method - * sorts the expressions based on field names. 
- */ - private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = { - fieldExprs.grouped(2).map { e => - Seq(e.head, e.last) - }.toSeq.sortBy { pair => - assert(pair.head.isInstanceOf[Literal]) - pair.head.eval().asInstanceOf[UTF8String].toString - }.flatten - } + val newStructFields = mutable.ArrayBuffer.empty[Expression] - /** - * This helper method sorts fields in a `UpdateFields` expression by field name. - */ - private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp { - case u: UpdateFields if u.resolved => - u.evalExpr match { - case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) => - val sorted = sortFieldExprs(fieldExprs) - val newStruct = CreateNamedStruct(sorted) - i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct) - case CreateNamedStruct(fieldExprs) => - val sorted = sortFieldExprs(fieldExprs) - val newStruct = CreateNamedStruct(sorted) - newStruct - case other => - throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " + - "Please file a bug report with this error message, stack trace, and the query.") - } - } + targetType.fields.foreach { expectedField => + val currentField = colType.fields.find(f => resolver(f.name, expectedField.name)) - /** - * Adds missing fields recursively into given `col` expression, based on the target `StructType`. - * This is called by `compareAndAddFields` when we find two struct columns with same name but - * different nested fields. This method will find out the missing nested fields from `col` to - * `target` struct and add these missing nested fields. Currently we don't support finding out - * missing nested fields of struct nested in array or struct nested in map. 
- */ - private def addFields(col: NamedExpression, target: StructType): Expression = { - assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") + val newExpression = (currentField, expectedField.dataType) match { + case (Some(cf), expectedType: StructType) if cf.dataType.isInstanceOf[StructType] => + val extractedValue = ExtractValue(col, Literal(cf.name), resolver) + addFields(extractedValue, expectedType) + case (Some(cf), _) => + ExtractValue(col, Literal(cf.name), resolver) + case (None, expectedType) => + Literal(null, expectedType) + } + newStructFields ++= Literal(expectedField.name) :: newExpression :: Nil + } - val resolver = conf.resolver - val missingFieldsOpt = - StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) + colType.fields + .filter(f => targetType.fields.find(tf => resolver(f.name, tf.name)).isEmpty) + .foreach { f => + newStructFields ++= Literal(f.name) :: ExtractValue(col, Literal(f.name), resolver) :: Nil + } - // We need to sort columns in result, because we might add another column in other side. - // E.g., we want to union two structs "a int, b long" and "a int, c string". - // If we don't sort, we will have "a int, b long, c string" and - // "a int, c string, b long", which are not compatible. - if (missingFieldsOpt.isEmpty) { - sortStructFields(col) + val newStruct = CreateNamedStruct(newStructFields.toSeq) + if (col.nullable) { + If(IsNull(col), Literal(null, newStruct.dataType), newStruct) } else { - missingFieldsOpt.map { s => - val struct = addFieldsInto(col, s.fields) - // Combines `WithFields`s to reduce expression tree. - val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields) - val sorted = sortStructFieldsInWithFields(reducedStruct) - sorted - }.get + newStruct } } - /** - * Adds missing fields recursively into given `col` expression. The missing fields are given - * in `fields`. 
For example, given `col` as "z struct, x int", and `fields` is - * "z struct, w string". This method will add a nested `z.w` field and a top-level - * `w` field to `col` and fill null values for them. Note that because we might also add missing - * fields at other side of Union, we must make sure corresponding attributes at two sides have - * same field order in structs, so when we adding missing fields, we will sort the fields based on - * field names. So the data type of returned expression will be - * "w string, x int, z struct". - */ - private def addFieldsInto( - col: Expression, - fields: Seq[StructField]): Expression = { - fields.foldLeft(col) { case (currCol, field) => - field.dataType match { - case st: StructType => - val resolver = conf.resolver - val colField = currCol.dataType.asInstanceOf[StructType] - .find(f => resolver(f.name, field.name)) - if (colField.isEmpty) { - // The whole struct is missing. Add a null. - UpdateFields(currCol, field.name, Literal(null, st)) - } else { - UpdateFields(currCol, field.name, - addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields)) - } - case dt => - UpdateFields(currCol, field.name, Literal(null, dt)) - } - } - } /** * This method will compare right to left plan's outputs. If there is one struct attribute @@ -175,11 +102,9 @@ object ResolveUnion extends Rule[LogicalPlan] { (foundDt, lattr.dataType) match { case (source: StructType, target: StructType) if allowMissingCol && !source.sameType(target) => - // Having an output with same name, but different struct type. - // We need to add missing fields. Note that if there are deeply nested structs such as - // nested struct of array in struct, we don't support to add missing deeply nested field - // like that. We will sort columns in the struct expression to make sure two sides of - // union have consistent schema. 
+ // We have two structs with different types, so make sure the two structs have their + // fields in the same order by using `target`'s fields and then including any remaining + // in `foundAttr`. aliased += foundAttr Alias(addFields(foundAttr, target), foundAttr.name)() case _ => @@ -208,13 +133,11 @@ object ResolveUnion extends Rule[LogicalPlan] { left: LogicalPlan, right: LogicalPlan, allowMissingCol: Boolean): LogicalPlan = { - val rightOutputAttrs = right.output - // Builds a project list for `right` based on `left` output names val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) // Delegates failure checks to `CheckAnalysis` - val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) + val notFoundAttrs = right.output.diff(rightProjectList ++ aliased) val rightChild = Project(rightProjectList ++ notFoundAttrs, right) // Builds a project for `logicalPlan` based on `right` output names, if allowing @@ -230,6 +153,7 @@ object ResolveUnion extends Rule[LogicalPlan] { } else { left } + Union(leftChild, rightChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 53651a20dc49..0b393604cd17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2081,10 +2081,8 @@ class Dataset[T] private[sql]( * }}} * * Note that `allowMissingColumns` supports nested column in struct types. Missing nested columns - * of struct columns with same name will also be filled with null values. This currently does not - * support nested columns in array and map types. Note that if there is any missing nested columns - * to be filled, in order to make consistent schema between two sides of union, the nested fields - * of structs will be sorted after merging schema. 
+ * of struct columns with the same name will also be filled with null values and added to the end + * of the struct. This currently does not support nested columns in array and map types. * * @group typedrel * @since 3.1.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 797673ae15ba..e622528afc69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -677,27 +677,30 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a") - val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>" - var unionDf = df1.unionByName(df2, true) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `c`: STRING, `b`: BIGINT>>") checkAnswer(unionDf, - Row(0, Row(0, 1, Row(1, null, "2"))) :: - Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) + Row(0, Row(0, 1, Row(1, "2", null))) :: + Row(1, Row(1, 2, Row(2, null, 3L))) :: Nil) unionDf = df2.unionByName(df1, true) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") checkAnswer(unionDf, Row(1, Row(1, 2, Row(2, 3L, null))) :: Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a") unionDf = df1.unionByName(df3, true) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `c`: STRING, `b`: BIGINT>>") 
checkAnswer(unionDf, - Row(0, Row(0, 1, Row(1, null, "2"))) :: + Row(0, Row(0, 1, Row(1, "2", null))) :: Row(2, Row(2, 3, null)) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" + @@ -707,29 +710,29 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a") var unionDf = df1.unionByName(df2, true) - checkAnswer(unionDf, - Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: - Row(1, Row(1, 2, Row(2, null, 3L, null))) :: Nil) assert(unionDf.schema.toDDL == "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>") + "`nested`: STRUCT<`a`: INT, `c`: STRING, `A`: INT, `b`: BIGINT>>") + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(1, "2", null, null))) :: + Row(1, Row(1, 2, Row(null, null, 2, 3L))) :: Nil) unionDf = df2.unionByName(df1, true) - checkAnswer(unionDf, - Row(1, Row(1, 2, Row(2, null, 3L, null))) :: - Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: Nil) assert(unionDf.schema.toDDL == "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>") + "`nested`: STRUCT<`A`: INT, `b`: BIGINT, `a`: INT, `c`: STRING>>") + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, 3L, null, null))) :: + Row(0, Row(0, 1, Row(null, null, 1, "2"))) :: Nil) val df3 = Seq((2, UnionClass1b(2, 3L, UnionClass3(4, 5L)))).toDF("id", "a") unionDf = df2.unionByName(df3, true) - checkAnswer(unionDf, - Row(1, Row(1, 2, Row(2, null, 3L))) :: - Row(2, Row(2, 3, Row(null, 4, 5L))) :: Nil) assert(unionDf.schema.toDDL == "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT>>") + "`nested`: STRUCT<`A`: INT, `b`: BIGINT, `a`: INT>>") + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, 3L, null))) :: + Row(2, Row(2, 3, Row(null, 5L, 4))) :: Nil) } } @@ -743,17 
+746,59 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { StructField("a", StringType))) val nestedStructValues2 = Row("b", "a") - val df1: DataFrame = spark.createDataFrame( + val df1 = spark.createDataFrame( sparkContext.parallelize(Row(nestedStructValues1) :: Nil), StructType(Seq(StructField("topLevelCol", nestedStructType1)))) - val df2: DataFrame = spark.createDataFrame( + val df2 = spark.createDataFrame( sparkContext.parallelize(Row(nestedStructValues2) :: Nil), StructType(Seq(StructField("topLevelCol", nestedStructType2)))) val union = df1.unionByName(df2, allowMissingColumns = true) - checkAnswer(union, Row(Row(null, "b")) :: Row(Row("a", "b")) :: Nil) - assert(union.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRING, `b`: STRING>") + assert(union.schema.toDDL == "`topLevelCol` STRUCT<`b`: STRING, `a`: STRING>") + checkAnswer(union, Row(Row("b", null)) :: Row(Row("b", "a")) :: Nil) + } + + test("SPARK-35290: Make unionByName null-filling behavior work with struct columns" + + " - sorting edge case") { + val nestedStructType1 = StructType(Seq( + StructField("b", StructType(Seq( + StructField("ba", StringType) + ))) + )) + val nestedStructValues1 = Row(Row("ba")) + + val nestedStructType2 = StructType(Seq( + StructField("a", StructType(Seq( + StructField("aa", StringType) + ))), + StructField("b", StructType(Seq( + StructField("bb", StringType) + ))) + )) + val nestedStructValues2 = Row(Row("aa"), Row("bb")) + + val df1 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues1) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType1)))) + + val df2 = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues2) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType2)))) + + var unionDf = df1.unionByName(df2, true) + assert(unionDf.schema.toDDL == "`topLevelCol` " + + "STRUCT<`b`: STRUCT<`ba`: STRING, `bb`: STRING>, `a`: STRUCT<`aa`: STRING>>") + checkAnswer(unionDf, + 
Row(Row(Row("ba", null), null)) :: + Row(Row(Row(null, "bb"), Row("aa"))) :: Nil) + + unionDf = df2.unionByName(df1, true) + assert(unionDf.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRUCT<`aa`: STRING>, " + + "`b`: STRUCT<`bb`: STRING, `ba`: STRING>>") + checkAnswer(unionDf, + Row(Row(null, Row(null, "ba"))) :: + Row(Row(Row("aa"), Row("bb", null))) :: Nil) } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr") { @@ -777,7 +822,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { depthCounter -= 1 } - val df: DataFrame = spark.createDataFrame( + val df = spark.createDataFrame( sparkContext.parallelize(Row(struct) :: Nil), StructType(Seq(StructField("nested0Col0", structType)))) @@ -800,16 +845,16 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)) val row2 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row( - Row(0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), - 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9)) + Row(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 
7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) // scalastyle:on checkAnswer(union, row1 :: row2 :: Nil) }