diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
index 693a5a4e7544..c1a9c9d3d9ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
@@ -17,29 +17,188 @@
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * Recursively sorts the columns of a struct expression by column name.
+   */
+  private def sortStructFields(expr: Expression): Expression = {
+    val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) =>
+        val fieldExpr = GetStructField(KnownNotNull(expr), i)
+        if (fieldExpr.dataType.isInstanceOf[StructType]) {
+          (name, sortStructFields(fieldExpr))
+        } else {
+          (name, fieldExpr)
+        }
+    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
+
+    val newExpr = CreateNamedStruct(existingExprs)
+    if (expr.nullable) {
+      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
+    } else {
+      newExpr
+    }
+  }
+
+  /**
+   * Assumes the input expressions are the field expressions of a `CreateNamedStruct`. This
+   * method sorts the expressions based on field names.
+   */
+  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
+    fieldExprs.grouped(2).map { e =>
+      Seq(e.head, e.last)
+    }.toSeq.sortBy { pair =>
+      assert(pair.head.isInstanceOf[Literal])
+      pair.head.eval().asInstanceOf[UTF8String].toString
+    }.flatten
+  }
+
+  /**
+   * This helper method sorts the fields in an `UpdateFields` expression by field name.
+   */
+  private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
+    case u: UpdateFields if u.resolved =>
+      u.evalExpr match {
+        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
+        case CreateNamedStruct(fieldExprs) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          newStruct
+        case other =>
+          throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " +
+            "Please file a bug report with this error message, stack trace, and the query.")
+      }
+  }
+
+  def simplifyWithFields(expr: Expression): Expression = {
+    expr.transformUp {
+      case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+        UpdateFields(struct, fieldOps1 ++ fieldOps2)
+    }
+  }
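Reviewer note: the recursive sort above is easier to see outside Catalyst. Below is a minimal, self-contained sketch of the same idea over a toy field model; the types are illustrative stand-ins, not Spark APIs:

```scala
object SortFieldsSketch extends App {
  sealed trait FieldValue
  case class Leaf(dataType: String) extends FieldValue
  case class Struct(fields: Seq[(String, FieldValue)]) extends FieldValue

  // Sort fields by name at each level, recursing into nested structs,
  // mirroring what sortStructFields does with GetStructField/CreateNamedStruct.
  def sortFields(value: FieldValue): FieldValue = value match {
    case leaf: Leaf => leaf
    case Struct(fields) =>
      Struct(fields.map { case (name, v) => (name, sortFields(v)) }.sortBy(_._1))
  }

  val example = Struct(Seq(
    "c" -> Leaf("string"),
    "a" -> Leaf("int"),
    "b" -> Struct(Seq("y" -> Leaf("int"), "x" -> Leaf("int")))))

  // Yields: a int, b struct<x int, y int>, c string
  println(sortFields(example))
}
```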
+
+  /**
+   * Adds missing fields recursively into the given `col` expression, based on the target
+   * `StructType`. This is called by `compareAndAddFields` when we find two struct columns
+   * with the same name but different nested fields. It finds the nested fields of `col` that
+   * are missing relative to the `target` struct and adds them. Currently we don't support
+   * finding missing nested fields of a struct nested in an array or in a map.
+   */
+  private def addFields(col: NamedExpression, target: StructType): Expression = {
+    assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")
+
+    val resolver = SQLConf.get.resolver
+    val missingFieldsOpt =
+      StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)
+
+    // We need to sort the columns in the result, because we might add another column on the
+    // other side. E.g., we want to union two structs "a int, b long" and "a int, c string".
+    // If we don't sort, we will have "a int, b long, c string" and
+    // "a int, c string, b long", which are not compatible.
+    if (missingFieldsOpt.isEmpty) {
+      sortStructFields(col)
+    } else {
+      missingFieldsOpt.map { s =>
+        val struct = addFieldsInto(col, s.fields)
+        // Combines `WithFields`s to reduce the expression tree.
+        val reducedStruct = simplifyWithFields(struct)
+        val sorted = sortStructFieldsInWithFields(reducedStruct)
+        sorted
+      }.get
+    }
+  }
+
+  /**
+   * Adds missing fields recursively into the given `col` expression. The missing fields are
+   * given in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and
+   * `fields` as "z struct<w:long>, w string", this method adds a nested `z.w` field and a
+   * top-level `w` field to `col` and fills them with null values. Note that because we might
+   * also add missing fields on the other side of the Union, we must make sure the
+   * corresponding attributes on both sides have the same field order in their structs, so
+   * when adding missing fields we sort the fields by name. The data type of the returned
+   * expression is therefore "w string, x int, z struct<w:long, y:int, z:int>".
+   */
+  private def addFieldsInto(
+      col: Expression,
+      fields: Seq[StructField]): Expression = {
+    fields.foldLeft(col) { case (currCol, field) =>
+      field.dataType match {
+        case st: StructType =>
+          val resolver = SQLConf.get.resolver
+          val colField = currCol.dataType.asInstanceOf[StructType]
+            .find(f => resolver(f.name, field.name))
+          if (colField.isEmpty) {
+            // The whole struct is missing. Add a null.
+            UpdateFields(currCol, field.name, Literal(null, st))
+          } else {
+            UpdateFields(currCol, field.name,
+              addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields))
+          }
+        case dt =>
+          UpdateFields(currCol, field.name, Literal(null, dt))
+      }
+    }
+  }
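To sanity-check the documented example, here is a hedged, standalone sketch of the merge-then-sort at the schema level, again using plain-Scala stand-ins rather than the Catalyst expressions above:

```scala
object AddFieldsSketch extends App {
  sealed trait DType
  case object IntT extends DType
  case object LongT extends DType
  case object StringT extends DType
  case class StructT(fields: Seq[(String, DType)]) extends DType

  // Add every field of `missing` into `base` (recursing into structs that exist on
  // both sides), then sort every level by name. Name matching here is plain equality;
  // the real code uses the configured resolver.
  def addAndSort(base: StructT, missing: StructT): StructT = {
    val merged = missing.fields.foldLeft(base.fields) { case (fields, (name, mType)) =>
      (fields.find(_._1 == name), mType) match {
        case (Some((_, bType: StructT)), mStruct: StructT) =>
          fields.map { case (n, t) => if (n == name) (n, addAndSort(bType, mStruct)) else (n, t) }
        case (None, _) => fields :+ (name -> mType) // missing field, null-filled at runtime
        case _ => fields
      }
    }
    StructT(merged.map {
      case (n, s: StructT) => (n, addAndSort(s, StructT(Nil))) // re-sort nested structs too
      case other => other
    }.sortBy(_._1))
  }

  // Doc example: "z struct<z:int, y:int>, x int" plus missing "z struct<w:long>, w string"
  val base = StructT(Seq("z" -> StructT(Seq("z" -> IntT, "y" -> IntT)), "x" -> IntT))
  val missing = StructT(Seq("z" -> StructT(Seq("w" -> LongT)), "w" -> StringT))
  // Prints the expected order: w string, x int, z struct<w:long, y:int, z:int>
  println(addAndSort(base, missing))
}
```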
+
+  /**
+   * Compares the outputs of the right plan with those of the left plan. If a struct attribute
+   * on the right side has the same name as a struct attribute on the left side but a different
+   * data type, i.e., the right struct attribute is missing some (nested) fields, this method
+   * adds the missing (nested) fields to the right attribute and fills them with null values.
+   */
+  private def compareAndAddFields(
       left: LogicalPlan,
       right: LogicalPlan,
-      allowMissingCol: Boolean): LogicalPlan = {
+      allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
     val resolver = SQLConf.get.resolver
     val leftOutputAttrs = left.output
     val rightOutputAttrs = right.output
 
-    // Builds a project list for `right` based on `left` output names
+    val aliased = mutable.ArrayBuffer.empty[Attribute]
+
     val rightProjectList = leftOutputAttrs.map { lattr =>
-      rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
+      val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
+      if (found.isDefined) {
+        val foundAttr = found.get
+        val foundDt = foundAttr.dataType
+        (foundDt, lattr.dataType) match {
+          case (source: StructType, target: StructType)
+              if allowMissingCol && !source.sameType(target) =>
+            // The outputs have the same name but different struct types, so we need to add the
+            // missing fields. Note that we don't support adding deeply nested missing fields,
+            // e.g. a struct nested in an array that is itself nested in a struct. We sort the
+            // columns in the struct expression to make sure the two sides of the union end up
+            // with a consistent schema.
+            aliased += foundAttr
+            Alias(addFields(foundAttr, target), foundAttr.name)()
+          case _ =>
+            // We don't need/try to add missing fields if:
+            // 1. The attributes on the left and right side have the same struct type.
+            // 2. The attributes are not struct types. They might be primitive, array, or map
+            //    types. We don't support adding missing fields of structs nested in array or
+            //    map types yet.
+            // 3. `allowMissingCol` is disabled.
+            foundAttr
+        }
+      } else {
         if (allowMissingCol) {
           Alias(Literal(null, lattr.dataType), lattr.name)()
         } else {
@@ -50,18 +209,29 @@ object ResolveUnion extends Rule[LogicalPlan] {
       }
     }
 
+    (rightProjectList, aliased.toSeq)
+  }
+
+  private def unionTwoSides(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      allowMissingCol: Boolean): LogicalPlan = {
+    val rightOutputAttrs = right.output
+
+    // Builds a project list for `right` based on `left` output names
+    val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)
+
     // Delegates failure checks to `CheckAnalysis`
-    val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
+    val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
     val rightChild = Project(rightProjectList ++ notFoundAttrs, right)
 
     // Builds a project for `logicalPlan` based on `right` output names, if allowing
     // missing columns.
     val leftChild = if (allowMissingCol) {
-      val missingAttrs = notFoundAttrs.map { attr =>
-        Alias(Literal(null, attr.dataType), attr.name)()
-      }
-      if (missingAttrs.nonEmpty) {
-        Project(leftOutputAttrs ++ missingAttrs, left)
+      // Add missing (nested) fields to the left plan.
+      val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
+      if (leftProjectList.map(_.toAttribute) != left.output) {
+        Project(leftProjectList, left)
       } else {
         left
       }
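End to end, the rewritten rule gives `unionByName` the behavior exercised by the new tests further down. A hypothetical spark-shell session (assumes `spark` and `spark.implicits._` are in scope):

```scala
val df1 = Seq(((1, 2, 3), 0)).toDF("a", "idx")  // a: struct<_1, _2, _3>
val df2 = Seq(((3, 4), 0)).toDF("a", "idx")     // a: struct<_1, _2>

val unioned = df1.unionByName(df2, allowMissingColumns = true)
unioned.printSchema()
// a: struct<_1: int, _2: int, _3: int> (rows coming from df2 get a._3 = null)
unioned.show()
```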
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index f6485a51f8fa..3958cfd0af2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -661,3 +662,52 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat
   }
 }
+
+object UpdateFields {
+  private def nameParts(fieldName: String): Seq[String] = {
+    require(fieldName != null, "fieldName cannot be null")
+
+    if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+  }
+
+  /**
+   * Adds or replaces a field of the `StructType` in the `col` expression, by name.
+   */
+  def apply(col: Expression, fieldName: String, expr: Expression): UpdateFields = {
+    updateFieldsHelper(col, nameParts(fieldName), name => WithField(name, expr))
+  }
+
+  /**
+   * Drops a field of the `StructType` in the `col` expression, by name.
+   */
+  def apply(col: Expression, fieldName: String): UpdateFields = {
+    updateFieldsHelper(col, nameParts(fieldName), name => DropField(name))
+  }
+
+  private def updateFieldsHelper(
+      structExpr: Expression,
+      namePartsRemaining: Seq[String],
+      valueFunc: String => StructFieldsOperation): UpdateFields = {
+    val fieldName = namePartsRemaining.head
+    if (namePartsRemaining.length == 1) {
+      UpdateFields(structExpr, valueFunc(fieldName) :: Nil)
+    } else {
+      val newStruct = if (structExpr.resolved) {
+        val resolver = SQLConf.get.resolver
+        ExtractValue(structExpr, Literal(fieldName), resolver)
+      } else {
+        UnresolvedExtractValue(structExpr, Literal(fieldName))
+      }
+
+      val newValue = updateFieldsHelper(
+        structExpr = newStruct,
+        namePartsRemaining = namePartsRemaining.tail,
+        valueFunc = valueFunc)
+      UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil)
+    }
+  }
+}
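The recursion in `updateFieldsHelper` is the key shape: a dotted name becomes a chain of extract/update pairs. A standalone sketch with toy stand-ins for the Catalyst nodes, for illustration only:

```scala
object NestedNameSketch extends App {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Extract(struct: Expr, field: String) extends Expr
  case class Update(struct: Expr, field: String, value: Expr) extends Expr
  case class Lit(value: Any) extends Expr

  // Fold a multipart name into nested updates: update the child struct,
  // then write the updated child back into its parent.
  def update(struct: Expr, parts: List[String], value: Expr): Expr = parts match {
    case last :: Nil => Update(struct, last, value)
    case head :: tail => Update(struct, head, update(Extract(struct, head), tail, value))
    case Nil => struct
  }

  // "a.b.c" -> Update(s, a, Update(s.a, b, Update(s.a.b, c, Lit(1))))
  println(update(Col("s"), List("a", "b", "c"), Lit(1)))
}
```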
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 89ff4facd25a..60afe140960c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -116,8 +116,10 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
     s"$child.${name.getOrElse(fieldName)}"
   }
 
+  def extractFieldName: String = name.getOrElse(childSchema(ordinal).name)
+
   override def sql: String =
-    child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
+    child.sql + s".${quoteIdentifier(extractFieldName)}"
 
   protected override def nullSafeEval(input: Any): Any =
     input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b14fb04cc453..c5e76c160ff4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -641,4 +641,39 @@ object StructType extends AbstractDataType {
     fields.foreach(s => map.put(s.name, s))
     map
   }
+
+  /**
+   * Returns a `StructType` containing the fields that are missing, recursively, from `source`
+   * relative to `target`. Note that this doesn't look into array and map types recursively.
+   */
+  def findMissingFields(
+      source: StructType,
+      target: StructType,
+      resolver: Resolver): Option[StructType] = {
+    def bothStructType(dt1: DataType, dt2: DataType): Boolean =
+      dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]
+
+    val newFields = mutable.ArrayBuffer.empty[StructField]
+
+    target.fields.foreach { field =>
+      val found = source.fields.find(f => resolver(field.name, f.name))
+      if (found.isEmpty) {
+        // Found a field missing from `source`.
+        newFields += field
+      } else if (bothStructType(found.get.dataType, field.dataType) &&
+          !found.get.dataType.sameType(field.dataType)) {
+        // Found a field with the same name but a different data type.
+        findMissingFields(found.get.dataType.asInstanceOf[StructType],
+          field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
+          newFields += found.get.copy(dataType = missingType)
+        }
+      }
+    }
+
+    if (newFields.isEmpty) {
+      None
+    } else {
+      Some(StructType(newFields.toSeq))
+    }
+  }
 }
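The contract of the new helper in one snippet, spark-shell style; the DDL here is my own example, chosen to mirror the new StructTypeSuite tests below:

```scala
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

val resolver = SQLConf.get.resolver
val target = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: LONG>")
val source = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")

// Only the nested field c2.c4 is missing from `source`.
val missing = StructType.findMissingFields(source, target, resolver)
assert(missing.exists(_.sameType(StructType.fromDDL("c2 STRUCT<c4: LONG>"))))
```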
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index 6824a64badc1..645e65f06508 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -18,9 +18,11 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType.fromDDL
 
-class StructTypeSuite extends SparkFunSuite {
+class StructTypeSuite extends SparkFunSuite with SQLHelper {
 
   private val s = StructType.fromDDL("a INT, b STRING")
 
@@ -103,4 +105,96 @@ class StructTypeSuite extends SparkFunSuite {
     val interval = "`a` INTERVAL"
     assert(fromDDL(interval).toDDL === interval)
   }
+
+  test("find missing (nested) fields") {
+    val schema = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    val resolver = SQLConf.get.resolver
+
+    val source1 = StructType.fromDDL("c1 INT")
+    val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source1, schema, resolver)
+      .exists(_.sameType(missing1)))
+
+    val source2 = StructType.fromDDL("c1 INT, c3 STRING")
+    val missing2 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source2, schema, resolver)
+      .exists(_.sameType(missing2)))
+
+    val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
+    val missing3 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source3, schema, resolver)
+      .exists(_.sameType(missing3)))
+
+    val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
+    val missing4 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT>>")
+    assert(StructType.findMissingFields(source4, schema, resolver)
+      .exists(_.sameType(missing4)))
+  }
+
+  test("find missing (nested) fields: array and map") {
+    val resolver = SQLConf.get.resolver
+
+    val schemaWithArray = StructType.fromDDL("c1 INT, c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
+    val source5 = StructType.fromDDL("c1 INT")
+    val missing5 = StructType.fromDDL("c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
+    assert(
+      StructType.findMissingFields(source5, schemaWithArray, resolver)
+        .exists(_.sameType(missing5)))
+
+    val schemaWithMap1 = StructType.fromDDL(
+      "c1 INT, c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>, c3 LONG")
+    val source6 = StructType.fromDDL("c1 INT, c3 LONG")
+    val missing6 = StructType.fromDDL("c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>")
+    assert(
+      StructType.findMissingFields(source6, schemaWithMap1, resolver)
+        .exists(_.sameType(missing6)))
+
+    val schemaWithMap2 = StructType.fromDDL(
+      "c1 INT, c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>, c3 STRING")
+    val source7 = StructType.fromDDL("c1 INT, c3 STRING")
+    val missing7 = StructType.fromDDL("c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>")
+    assert(
+      StructType.findMissingFields(source7, schemaWithMap2, resolver)
+        .exists(_.sameType(missing7)))
+
+    // Unsupported: nested structs in array and map types
+    val source8 = StructType.fromDDL("c1 INT, c2 ARRAY<STRUCT<c3: INT>>")
+    // `findMissingFields` doesn't support looking into structs nested in array types.
+    assert(StructType.findMissingFields(source8, schemaWithArray, resolver).isEmpty)
+
+    val source9 = StructType.fromDDL("c1 INT, c2 MAP<STRUCT<c3: INT>, STRING>, c3 LONG")
+    // `findMissingFields` doesn't support looking into structs nested in map types.
+    assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).isEmpty)
+
+    val source10 = StructType.fromDDL("c1 INT, c2 MAP<STRING, STRUCT<c3: INT>>, c3 STRING")
+    // `findMissingFields` doesn't support looking into structs nested in map types.
+    assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).isEmpty)
+  }
+
+  test("find missing (nested) fields: case sensitive cases") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val schema = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+      val resolver = SQLConf.get.resolver
+
+      val source1 = StructType.fromDDL("c1 INT, C2 LONG")
+      val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+      assert(StructType.findMissingFields(source1, schema, resolver)
+        .exists(_.sameType(missing1)))
+
+      val source2 = StructType.fromDDL("c2 LONG")
+      val missing2 = StructType.fromDDL("c1 INT")
+      assert(StructType.findMissingFields(source2, schema, resolver)
+        .exists(_.sameType(missing2)))
+
+      val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<C3: INT, c4: STRUCT<C5: INT, c6: INT>>")
+      val missing3 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT>>")
+      assert(StructType.findMissingFields(source3, schema, resolver)
+        .exists(_.sameType(missing3)))
+
+      val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4: STRUCT<c5: INT, c6: INT>>")
+      val missing4 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
+      assert(StructType.findMissingFields(source4, schema, resolver)
+        .exists(_.sameType(missing4)))
+    }
+  }
 }
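The case-sensitivity coverage above hinges entirely on which resolver is passed in. For intuition, the two resolvers involved (both defined in `org.apache.spark.sql.catalyst.analysis`):

```scala
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}

assert(caseInsensitiveResolution("C2", "c2"))  // names match under the default resolver
assert(!caseSensitiveResolution("C2", "c2"))   // no match when spark.sql.caseSensitive=true
```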
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a46d6c0bb228..30792c9bacd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -925,7 +925,7 @@ class Column(val expr: Expression) extends Logging {
   def withField(fieldName: String, col: Column): Column = withExpr {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
-    updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+    UpdateFields(expr, fieldName, col.expr)
   }
 
   // scalastyle:off line.size.limit
@@ -989,38 +989,8 @@ class Column(val expr: Expression) extends Logging {
    */
   // scalastyle:on line.size.limit
   def dropFields(fieldNames: String*): Column = withExpr {
-    def dropField(structExpr: Expression, fieldName: String): UpdateFields =
-      updateFieldsHelper(structExpr, nameParts(fieldName), name => DropField(name))
-
-    fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
-      (resExpr, fieldName) => dropField(resExpr, fieldName)
-    }
-  }
-
-  private def nameParts(fieldName: String): Seq[String] = {
-    require(fieldName != null, "fieldName cannot be null")
-
-    if (fieldName.isEmpty) {
-      fieldName :: Nil
-    } else {
-      CatalystSqlParser.parseMultipartIdentifier(fieldName)
-    }
-  }
-
-  private def updateFieldsHelper(
-      structExpr: Expression,
-      namePartsRemaining: Seq[String],
-      valueFunc: String => StructFieldsOperation): UpdateFields = {
-
-    val fieldName = namePartsRemaining.head
-    if (namePartsRemaining.length == 1) {
-      UpdateFields(structExpr, valueFunc(fieldName) :: Nil)
-    } else {
-      val newValue = updateFieldsHelper(
-        structExpr = UnresolvedExtractValue(structExpr, Literal(fieldName)),
-        namePartsRemaining = namePartsRemaining.tail,
-        valueFunc = valueFunc)
-      UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil)
+    fieldNames.tail.foldLeft(UpdateFields(expr, fieldNames.head)) {
+      (resExpr, fieldName) => UpdateFields(resExpr, fieldName)
     }
   }
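The user-facing `Column` API is unchanged by this refactoring; the helpers only moved so the analyzer can reuse them. A hypothetical spark-shell refresher on the behavior they back:

```scala
import org.apache.spark.sql.functions._

val df = spark.sql("SELECT named_struct('a', 1, 'b', named_struct('c', 2, 'd', 3)) AS s")
df.select(col("s").withField("b.e", lit(4))).printSchema()   // adds the nested field s.b.e
df.select(col("s").dropFields("b.c", "b.d")).printSchema()   // drops two nested fields
```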
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 87b9aea80c82..3d431d6ff13a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2067,6 +2067,12 @@ class Dataset[T] private[sql](
    * // +----+----+----+----+
    * }}}
    *
+   * Note that `allowMissingColumns` supports nested columns in struct types. Missing nested
+   * columns of struct columns with the same name are also filled with null values. This
+   * currently does not support nested columns in array and map types. Also note that if any
+   * missing nested columns need to be filled, the nested struct fields are sorted after
+   * merging the schemas, to keep the schema consistent between the two sides of the union.
+   *
    * @group typedrel
    * @since 3.1.0
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index e72b8ce860b2..5f28dc60962b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -536,4 +536,185 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
       assert(union2.schema.fieldNames === Array("a", "B", "C", "c"))
     }
   }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - simple") {
+    val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx")
+    val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx")
+    val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2))
+      .toDF("a", "idx")
+
+    var unionDf = df1.unionByName(df2, true)
+
+    checkAnswer(unionDf,
+      Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) ::
+        Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil
+    )
+
+    assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT")
+
+    unionDf = df1.unionByName(df2, true).unionByName(df3, true)
+
+    checkAnswer(unionDf,
+      Row(Row(1, 2, 3, null), 0) ::
+        Row(Row(2, 3, 4, null), 1) ::
+        Row(Row(3, 4, 5, null), 2) :: // df1
+        Row(Row(3, 4, null, null), 0) ::
+        Row(Row(1, 2, null, null), 1) ::
+        Row(Row(2, 3, null, null), 2) :: // df2
+        Row(Row(100, 101, 102, 103), 0) ::
+        Row(Row(110, 111, 112, 113), 1) ::
+        Row(Row(120, 121, 122, 123), 2) :: Nil // df3
+    )
+    assert(unionDf.schema.toDDL ==
+      "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT")
+  }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - nested") {
+    val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
+    val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a")
+
+    val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+      "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>"
+
+    var unionDf = df1.unionByName(df2, true)
+    checkAnswer(unionDf,
+      Row(0, Row(0, 1, Row(1, null, "2"))) ::
+        Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil)
+    assert(unionDf.schema.toDDL == expectedSchema)
+
+    unionDf = df2.unionByName(df1, true)
+    checkAnswer(unionDf,
+      Row(1, Row(1, 2, Row(2, 3L, null))) ::
+        Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil)
+    assert(unionDf.schema.toDDL == expectedSchema)
+
+    val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a")
+    unionDf = df1.unionByName(df3, true)
+    checkAnswer(unionDf,
+      Row(0, Row(0, 1, Row(1, null, "2"))) ::
+        Row(2, Row(2, 3, null)) :: Nil)
+    assert(unionDf.schema.toDDL == expectedSchema)
+  }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" +
+    " - case-sensitive cases") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
+      val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a")
+
+      var unionDf = df1.unionByName(df2, true)
+      checkAnswer(unionDf,
+        Row(0, Row(0, 1, Row(null, 1, null, "2"))) ::
+          Row(1, Row(1, 2, Row(2, null, 3L, null))) :: Nil)
+      assert(unionDf.schema.toDDL ==
+        "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+        "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>")
+
+      unionDf = df2.unionByName(df1, true)
+      checkAnswer(unionDf,
+        Row(1, Row(1, 2, Row(2, null, 3L, null))) ::
+          Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: Nil)
+      assert(unionDf.schema.toDDL ==
+        "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+        "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>")
+
+      val df3 = Seq((2, UnionClass1b(2, 3L, UnionClass3(4, 5L)))).toDF("id", "a")
+      unionDf = df2.unionByName(df3, true)
+      checkAnswer(unionDf,
+        Row(1, Row(1, 2, Row(2, null, 3L))) ::
+          Row(2, Row(2, 3, Row(null, 4, 5L))) :: Nil)
+      assert(unionDf.schema.toDDL ==
+        "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+        "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT>>")
+    }
+  }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - edge case") {
+    val nestedStructType1 = StructType(Seq(
+      StructField("b", StringType)))
+    val nestedStructValues1 = Row("b")
+
+    val nestedStructType2 = StructType(Seq(
+      StructField("b", StringType),
+      StructField("a", StringType)))
+    val nestedStructValues2 = Row("b", "a")
+
+    val df1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType1))))
+
+    val df2: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType2))))
+
+    val union = df1.unionByName(df2, allowMissingColumns = true)
+    checkAnswer(union, Row(Row(null, "b")) :: Row(Row("a", "b")) :: Nil)
+    assert(union.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRING, `b`: STRING>")
+  }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr") {
+    def nestedDf(depth: Int, numColsAtEachDepth: Int): DataFrame = {
+      val initialNestedStructType = StructType(
+        (0 to numColsAtEachDepth).map(i =>
+          StructField(s"nested${depth}Col$i", IntegerType, nullable = false))
+      )
+      val initialNestedValues = Row(0 to numColsAtEachDepth: _*)
+
+      var depthCounter = depth - 1
+      var structType = initialNestedStructType
+      var struct = initialNestedValues
+      while (depthCounter != 0) {
+        struct = Row((struct +: (1 to numColsAtEachDepth)): _*)
+        structType = StructType(
+          StructField(s"nested${depthCounter}Col0", structType, nullable = false) +:
+            (1 to numColsAtEachDepth).map(i =>
+              StructField(s"nested${depthCounter}Col$i", IntegerType, nullable = false))
+        )
+        depthCounter -= 1
+      }
+
+      val df: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(struct) :: Nil),
+        StructType(Seq(StructField("nested0Col0", structType))))
+
+      df
+    }
+
+    val df1 = nestedDf(depth = 10, numColsAtEachDepth = 1)
+    val df2 = nestedDf(depth = 10, numColsAtEachDepth = 20)
+    val union = df1.unionByName(df2, allowMissingColumns = true)
+    // scalastyle:off
+    val row1 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row(
+      Row(0, 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null))
+    val row2 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row(
+      Row(0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9))
+    // scalastyle:on
+    checkAnswer(union, row1 :: row2 :: Nil)
+  }
 }
+
+case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
+case class UnionClass1b(a: Int, b: Long, nested: UnionClass3)
+case class UnionClass1c(a: Int, b: Long, nested: UnionClass4)
+
+case class UnionClass2(a: Int, c: String)
+case class UnionClass3(a: Int, b: Long)
+case class UnionClass4(A: Int, b: Long)