From 95e8fd44396a0843bd0be4722bb2c1724f084a91 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 30 Aug 2020 19:26:32 -0700
Subject: [PATCH 01/23] Make unionByName null-filling behavior work with
 struct columns.

---
 .../sql/catalyst/analysis/ResolveUnion.scala  | 103 +++++++++++++++---
 .../expressions/complexTypeCreator.scala      |  68 +++++++++++-
 .../sql/catalyst/optimizer/ComplexTypes.scala |   2 +-
 .../sql/catalyst/optimizer/WithFields.scala   |   3 +-
 .../apache/spark/sql/types/StructType.scala   |  26 +++++
 .../spark/sql/types/StructTypeSuite.scala     |  27 +++++
 .../scala/org/apache/spark/sql/Column.scala   |  27 +----
 .../sql/DataFrameSetOperationsSuite.scala     |  67 ++++++++++++
 8 files changed, 277 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
index 693a5a4e7544..c0eace605259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
@@ -17,29 +17,97 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, WithFields}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * Adds missing fields recursively into given `col` expression, based on the target `StructType`.
+   * For example, given `col` as "a struct<a1: string>, b int" and `target` as
+   * "a struct<a1: string, c: long>, b int, c string", this method should add `a.c` and `c` to
+   * `col` expression.
+   */
+  private def addFields(col: NamedExpression, target: StructType): Option[Expression] = {
+    require(col.dataType.isInstanceOf[StructType], "Only support StructType.")
+
+    val resolver = SQLConf.get.resolver
+    val missingFields =
+      StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)
+    if (missingFields.length == 0) {
+      None
+    } else {
+      Some(addFieldsInto(col, "", missingFields.fields))
+    }
+  }
+
+  private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = {
+    var currCol = col
+    fields.foreach { field =>
+      field.dataType match {
+        case dt: AtomicType =>
+          // We need to sort columns in result, because we might add another column in other side.
+          // E.g., we want to union two structs "a int, b long" and "a int, c string".
+          // If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long",
+          // which are not compatible.
+          currCol = WithFields(currCol, s"$base${field.name}", Literal(null, dt),
+            sortColumns = true)
+        case st: StructType =>
+          val resolver = SQLConf.get.resolver
+          val colField = currCol.dataType.asInstanceOf[StructType]
+            .find(f => resolver(f.name, field.name))
+          if (colField.isEmpty) {
+            // The whole struct is missing. Add a null.
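+            // (typed with the full missing struct type, so both sides of the union keep the same schema)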
+ currCol = WithFields(currCol, s"$base${field.name}", Literal(null, st), + sortColumns = true) + } else { + currCol = addFieldsInto(currCol, s"$base${field.name}.", st.fields) + } + } + } + currCol + } + + private def compareAndAddFields( left: LogicalPlan, right: LogicalPlan, - allowMissingCol: Boolean): LogicalPlan = { + allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = { val resolver = SQLConf.get.resolver val leftOutputAttrs = left.output val rightOutputAttrs = right.output - // Builds a project list for `right` based on `left` output names + val aliased = mutable.ArrayBuffer.empty[Attribute] + val rightProjectList = leftOutputAttrs.map { lattr => - rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { + val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) } + if (found.isDefined) { + val foundDt = found.get.dataType + (foundDt, lattr.dataType) match { + case (source: StructType, target: StructType) + if allowMissingCol && !source.sameType(target) => + // Having an output with same name, but different struct type. + // We need to add missing fields. + addFields(found.get, target).map { added => + aliased += found.get + Alias(added, found.get.name)() + }.getOrElse(found.get) // Data type doesn't change. We should add fields at other side. + case _ => + // Same struct type, or + // unsupported: different types, array or map types, or + // `allowMissingCol` is disabled. + found.get + } + } else { if (allowMissingCol) { Alias(Literal(null, lattr.dataType), lattr.name)() } else { @@ -50,21 +118,28 @@ object ResolveUnion extends Rule[LogicalPlan] { } } + (rightProjectList, aliased) + } + + private def unionTwoSides( + left: LogicalPlan, + right: LogicalPlan, + allowMissingCol: Boolean): LogicalPlan = { + val rightOutputAttrs = right.output + + // Builds a project list for `right` based on `left` output names + val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) + // Delegates failure checks to `CheckAnalysis` - val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) + val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) val rightChild = Project(rightProjectList ++ notFoundAttrs, right) // Builds a project for `logicalPlan` based on `right` output names, if allowing // missing columns. val leftChild = if (allowMissingCol) { - val missingAttrs = notFoundAttrs.map { attr => - Alias(Literal(null, attr.dataType), attr.name)() - } - if (missingAttrs.nonEmpty) { - Project(leftOutputAttrs ++ missingAttrs, left) - } else { - left - } + // Add missing (nested) fields to left plan. 
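+      // `rightChild` already carries the merged schema, so comparing `left` against it
+      // yields exactly the (nested) fields the left side lacks.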
+ val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol) + Project(leftProjectList, left) } else { left } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 563ce7133a3d..87f229fa83e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -546,7 +547,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E case class WithFields( structExpr: Expression, names: Seq[String], - valExprs: Seq[Expression]) extends Unevaluable { + valExprs: Seq[Expression], + sortColumns: Boolean = false) extends Unevaluable { assert(names.length == valExprs.length) @@ -585,9 +587,15 @@ case class WithFields( } else { resultExprs :+ newExpr } - }.flatMap { case (name, expr) => Seq(Literal(name), expr) } + } - val expr = CreateNamedStruct(newExprs) + val finalExprs = if (sortColumns) { + newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) } + } else { + newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) } + } + + val expr = CreateNamedStruct(finalExprs) if (structExpr.nullable) { If(IsNull(structExpr), Literal(null, expr.dataType), expr) } else { @@ -595,3 +603,55 @@ case class WithFields( } } } + +object WithFields { + /** + * Adds/replaces field in `StructType` into `col` expression by name. 
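+   * `fieldName` may be a multipart identifier (e.g. "a.b"), in which case the value is
+   * added or replaced at the nested position.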
+ */ + def apply(col: Expression, fieldName: String, expr: Expression): Expression = { + WithFields(col, fieldName, expr, false) + } + + def apply( + col: Expression, + fieldName: String, + expr: Expression, + sortColumns: Boolean): Expression = { + val nameParts = if (fieldName.isEmpty) { + fieldName :: Nil + } else { + CatalystSqlParser.parseMultipartIdentifier(fieldName) + } + withFieldHelper(col, nameParts, Nil, expr, sortColumns) + } + + private def withFieldHelper( + struct: Expression, + namePartsRemaining: Seq[String], + namePartsDone: Seq[String], + value: Expression, + sortColumns: Boolean) : WithFields = { + val name = namePartsRemaining.head + if (namePartsRemaining.length == 1) { + WithFields(struct, name :: Nil, value :: Nil, sortColumns) + } else { + val newNamesRemaining = namePartsRemaining.tail + val newNamesDone = namePartsDone :+ name + + val newStruct = if (struct.resolved) { + val resolver = SQLConf.get.resolver + ExtractValue(struct, Literal(name), resolver) + } else { + UnresolvedExtractValue(struct, Literal(name)) + } + + val newValue = withFieldHelper( + struct = newStruct, + namePartsRemaining = newNamesRemaining, + namePartsDone = newNamesDone, + value = value, + sortColumns = sortColumns) + WithFields(struct, name :: Nil, newValue :: Nil, sortColumns) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 1c33a2c7c313..87005cbde11d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -39,7 +39,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // Remove redundant field extraction. 
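+      // e.g. extracting field 0 of named_struct('a', x, 'b', y) simplifies to x.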
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) - case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) => + case GetStructField(w @ WithFields(struct, names, valExprs, _), ordinal, maybeName) => val name = w.dataType(ordinal).name val matches = names.zip(valExprs).filter(_._1 == name) if (matches.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala index 05c90864e4bb..44572a9c46d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object CombineWithFields extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => + case WithFields(WithFields(struct, names1, valExprs1, sort1), names2, valExprs2, sort2) + if sort1 == sort2 => WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b14fb04cc453..fa97fe233e9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -641,4 +641,30 @@ object StructType extends AbstractDataType { fields.foreach(s => map.put(s.name, s)) map } + + /** + * Returns a `StructType` that contains missing fields recursively from `source` to `target`. + * Note that this doesn't support looking into array type and map type recursively. + */ + def findMissingFields(source: StructType, target: StructType, resolver: Resolver): StructType = { + def bothStructType(dt1: DataType, dt2: DataType): Boolean = + dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType] + + val newFields = mutable.ArrayBuffer.empty[StructField] + + target.fields.foreach { field => + val found = source.fields.find(f => resolver(field.name, f.name)) + if (found.isEmpty) { + // Found a missing field in `source`. + newFields += field + } else if (bothStructType(found.get.dataType, field.dataType) && + !found.get.dataType.sameType(field.dataType)) { + // Found a field with same name, but different data type. 
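+        // Recurse into both structs so that only the truly missing nested fields are kept.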
+        newFields += found.get.copy(dataType =
+          findMissingFields(found.get.dataType.asInstanceOf[StructType],
+            field.dataType.asInstanceOf[StructType], resolver))
+      }
+    }
+    StructType(newFields)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index 6824a64badc1..3f5bf56662f9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType.fromDDL
 
 class StructTypeSuite extends SparkFunSuite {
@@ -103,4 +104,30 @@ class StructTypeSuite extends SparkFunSuite {
     val interval = "`a` INTERVAL"
     assert(fromDDL(interval).toDDL === interval)
   }
+
+  test("find missing (nested) fields") {
+    val schema = StructType.fromDDL(
+      "c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    val resolver = SQLConf.get.resolver
+
+    val source1 = StructType.fromDDL("c1 INT")
+    val missing1 = StructType.fromDDL(
+      "c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source1, schema, resolver).sameType(missing1))
+
+    val source2 = StructType.fromDDL("c1 INT, c3 STRING")
+    val missing2 = StructType.fromDDL(
+      "c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source2, schema, resolver).sameType(missing2))
+
+    val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
+    val missing3 = StructType.fromDDL(
+      "c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source3, schema, resolver).sameType(missing3))
+
+    val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
+    val missing4 = StructType.fromDDL(
+      "c2 STRUCT<c4: STRUCT<c5: INT>>")
+    assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index da542c67d9c5..dabcd905587f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -909,32 +909,7 @@ class Column(val expr: Expression) extends Logging {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
 
-    val nameParts = if (fieldName.isEmpty) {
-      fieldName :: Nil
-    } else {
-      CatalystSqlParser.parseMultipartIdentifier(fieldName)
-    }
-    withFieldHelper(expr, nameParts, Nil, col.expr)
-  }
-
-  private def withFieldHelper(
-      struct: Expression,
-      namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
-    if (namePartsRemaining.length == 1) {
-      WithFields(struct, name :: Nil, value :: Nil)
-    } else {
-      val newNamesRemaining = namePartsRemaining.tail
-      val newNamesDone = namePartsDone :+ name
-      val newValue = withFieldHelper(
-        struct = UnresolvedExtractValue(struct, Literal(name)),
-        namePartsRemaining = newNamesRemaining,
-        namePartsDone = newNamesDone,
-        value = value)
-      WithFields(struct, name :: Nil, newValue :: Nil)
-    }
+    WithFields(expr, fieldName, col.expr)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index e72b8ce860b2..72b387b7fdff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -536,4 +536,71 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { assert(union2.schema.fieldNames === Array("a", "B", "C", "c")) } } + + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - 1") { + val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx") + val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx") + val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2)) + .toDF("a", "idx") + + var unionDf = df1.unionByName(df2, true) + + checkAnswer(unionDf, + Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) :: // df1 + Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil // df2 + ) + + assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT") + + unionDf = df1.unionByName(df2, true).unionByName(df3, true) + + checkAnswer(unionDf, + Row(Row(1, 2, 3, null), 0) :: + Row(Row(2, 3, 4, null), 1) :: + Row(Row(3, 4, 5, null), 2) :: // df1 + Row(Row(3, 4, null, null), 0) :: + Row(Row(1, 2, null, null), 1) :: + Row(Row(2, 3, null, null), 2) :: // df2 + Row(Row(100, 101, 102, 103), 0) :: + Row(Row(110, 111, 112, 113), 1) :: + Row(Row(120, 121, 122, 123), 2) :: Nil // df3 + ) + assert(unionDf.schema.toDDL == + "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT") + } + + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - 2") { + val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") + val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a") + + var unionDf = df1.unionByName(df2, true) + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(1, null, "2"))) :: + Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") + + unionDf = df2.unionByName(df1, true) + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, 3L, null))) :: + Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") + + val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a") + unionDf = df1.unionByName(df3, true) + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(1, null, "2"))) :: + Row(2, Row(2, 3, null)) :: Nil) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") + } } + +case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) +case class UnionClass1b(a: Int, b: Long, nested: UnionClass3) +case class UnionClass2(a: Int, c: String) +case class UnionClass3(a: Int, b: Long) From 5db1e0f9cb3323888b007cef62483d11f7a84773 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 30 Aug 2020 20:47:36 -0700 Subject: [PATCH 02/23] Remove unnecessary Project. 
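
If compareAndAddFields adds nothing for the left side, the rebuilt project
list is exactly left.output, so wrapping `left` in a Project is a no-op and
can be skipped. A quick way to check in spark-shell (a sketch, assuming a
build with this patch applied):

    // Identical struct schemas on both sides: no null-filling is needed, so
    // the analyzed plan should show no extra Project over df1's plan.
    val df1 = Seq((1, (2, 3))).toDF("id", "a")
    val df2 = Seq((4, (5, 6))).toDF("id", "a")
    df1.unionByName(df2, allowMissingColumns = true).explain(true)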
--- .../apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index c0eace605259..42d82b831ccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -139,7 +139,11 @@ object ResolveUnion extends Rule[LogicalPlan] { val leftChild = if (allowMissingCol) { // Add missing (nested) fields to left plan. val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol) - Project(leftProjectList, left) + if (leftProjectList.map(_.toAttribute) != left.output) { + Project(leftProjectList, left) + } else { + left + } } else { left } From 8bec8a3b875a229b096af3d34222316fc33f763a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 31 Aug 2020 18:04:03 -0700 Subject: [PATCH 03/23] Address comments. --- .../sql/catalyst/analysis/ResolveUnion.scala | 26 ++++++------ .../expressions/complexTypeCreator.scala | 16 ++++---- .../spark/sql/types/StructTypeSuite.scala | 40 +++++++++++++++++++ 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 42d82b831ccd..fdce7f9a3c97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -39,7 +39,7 @@ object ResolveUnion extends Rule[LogicalPlan] { * `col` expression. */ private def addFields(col: NamedExpression, target: StructType): Option[Expression] = { - require(col.dataType.isInstanceOf[StructType], "Only support StructType.") + assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") val resolver = SQLConf.get.resolver val missingFields = @@ -52,30 +52,28 @@ object ResolveUnion extends Rule[LogicalPlan] { } private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = { - var currCol = col - fields.foreach { field => + fields.foldLeft(col) { case (currCol, field) => field.dataType match { - case dt: AtomicType => - // We need to sort columns in result, because we might add another column in other side. - // E.g., we want to union two structs "a int, b long" and "a int, c string". - // If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long", - // which are not compatible. - currCol = WithFields(currCol, s"$base${field.name}", Literal(null, dt), - sortColumns = true) case st: StructType => val resolver = SQLConf.get.resolver val colField = currCol.dataType.asInstanceOf[StructType] .find(f => resolver(f.name, field.name)) if (colField.isEmpty) { // The whole struct is missing. Add a null. - currCol = WithFields(currCol, s"$base${field.name}", Literal(null, st), - sortColumns = true) + WithFields(currCol, s"$base${field.name}", Literal(null, st), + sortOutputColumns = true) } else { - currCol = addFieldsInto(currCol, s"$base${field.name}.", st.fields) + addFieldsInto(currCol, s"$base${field.name}.", st.fields) } + case dt => + // We need to sort columns in result, because we might add another column in other side. 
+ // E.g., we want to union two structs "a int, b long" and "a int, c string". + // If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long", + // which are not compatible. + WithFields(currCol, s"$base${field.name}", Literal(null, dt), + sortOutputColumns = true) } } - currCol } private def compareAndAddFields( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 87f229fa83e2..d11bb195450a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -548,7 +548,7 @@ case class WithFields( structExpr: Expression, names: Seq[String], valExprs: Seq[Expression], - sortColumns: Boolean = false) extends Unevaluable { + sortOutputColumns: Boolean = false) extends Unevaluable { assert(names.length == valExprs.length) @@ -589,7 +589,7 @@ case class WithFields( } } - val finalExprs = if (sortColumns) { + val finalExprs = if (sortOutputColumns) { newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) } } else { newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) } @@ -616,13 +616,13 @@ object WithFields { col: Expression, fieldName: String, expr: Expression, - sortColumns: Boolean): Expression = { + sortOutputColumns: Boolean): Expression = { val nameParts = if (fieldName.isEmpty) { fieldName :: Nil } else { CatalystSqlParser.parseMultipartIdentifier(fieldName) } - withFieldHelper(col, nameParts, Nil, expr, sortColumns) + withFieldHelper(col, nameParts, Nil, expr, sortOutputColumns) } private def withFieldHelper( @@ -630,10 +630,10 @@ object WithFields { namePartsRemaining: Seq[String], namePartsDone: Seq[String], value: Expression, - sortColumns: Boolean) : WithFields = { + sortOutputColumns: Boolean) : WithFields = { val name = namePartsRemaining.head if (namePartsRemaining.length == 1) { - WithFields(struct, name :: Nil, value :: Nil, sortColumns) + WithFields(struct, name :: Nil, value :: Nil, sortOutputColumns) } else { val newNamesRemaining = namePartsRemaining.tail val newNamesDone = namePartsDone :+ name @@ -650,8 +650,8 @@ object WithFields { namePartsRemaining = newNamesRemaining, namePartsDone = newNamesDone, value = value, - sortColumns = sortColumns) - WithFields(struct, name :: Nil, newValue :: Nil, sortColumns) + sortOutputColumns = sortOutputColumns) + WithFields(struct, name :: Nil, newValue :: Nil, sortOutputColumns) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 3f5bf56662f9..0f6171b31028 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -129,5 +129,45 @@ class StructTypeSuite extends SparkFunSuite { val missing4 = StructType.fromDDL( "c2 STRUCT>") assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4)) + + val schemaWithArray = StructType.fromDDL( + "c1 INT, c2 ARRAY>") + val source5 = StructType.fromDDL( + "c1 INT") + val missing5 = StructType.fromDDL( + "c2 ARRAY>") + assert(StructType.findMissingFields(source5, schemaWithArray, resolver).sameType(missing5)) + + val schemaWithMap1 = StructType.fromDDL( + "c1 INT, c2 MAP, 
STRING>, c3 LONG") + val source6 = StructType.fromDDL( + "c1 INT, c3 LONG") + val missing6 = StructType.fromDDL( + "c2 MAP, STRING>") + assert(StructType.findMissingFields(source6, schemaWithMap1, resolver).sameType(missing6)) + + val schemaWithMap2 = StructType.fromDDL( + "c1 INT, c2 MAP>, c3 STRING") + val source7 = StructType.fromDDL( + "c1 INT, c3 STRING") + val missing7 = StructType.fromDDL( + "c2 MAP>") + assert(StructType.findMissingFields(source7, schemaWithMap2, resolver).sameType(missing7)) + + // Unsupported: nested struct in array, map + val source8 = StructType.fromDDL( + "c1 INT, c2 ARRAY>") + // `findMissingFields` doesn't support looking into nested struct in array type. + assert(StructType.findMissingFields(source8, schemaWithArray, resolver).length == 0) + + val source9 = StructType.fromDDL( + "c1 INT, c2 MAP, STRING>, c3 LONG") + // `findMissingFields` doesn't support looking into nested struct in map type. + assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).length == 0) + + val source10 = StructType.fromDDL( + "c1 INT, c2 MAP>, c3 STRING") + // `findMissingFields` doesn't support looking into nested struct in map type. + assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).length == 0) } } From 2515d7871a9358ed512896534828dc84fd977750 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Sep 2020 21:51:26 -0700 Subject: [PATCH 04/23] Add some comments. --- .../apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 8 +++++++- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index fdce7f9a3c97..ee9c30299c18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -35,7 +35,7 @@ object ResolveUnion extends Rule[LogicalPlan] { /** * Adds missing fields recursively into given `col` expression, based on the target `StructType`. * For example, given `col` as "a struct, b int" and `target` as - * "a struct, b int, c string", this method should add `a.c` and `c` to + * "a struct, b int, c string", this method should add `a.c` and `c` to * `col` expression. */ private def addFields(col: NamedExpression, target: StructType): Option[Expression] = { @@ -51,6 +51,12 @@ object ResolveUnion extends Rule[LogicalPlan] { } } + /** + * Adds missing fields recursively into given `col` expression. The missing fields are given + * in `fields`. For example, given `col` as "a struct, b int", and `fields` is + * "a struct, c string". This method will add a nested `a.c` field and a top-level + * `c` field to `col` and fill null values for them. 
+ */ private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = { fields.foldLeft(col) { case (currCol, field) => field.dataType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7b0bae6a8205..4b8d64fb6cf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2067,6 +2067,10 @@ class Dataset[T] private[sql]( * // +----+----+----+----+ * }}} * + * Note that `allowMissingColumns` supports nested column in struct types. Missing nested columns + * of struct columns with same name will also be filled with null values. This currently does not + * support nested columns in array and map types. + * * @group typedrel * @since 3.1.0 */ From 4398e77f1b5dfc6a03adbb38526c3b3cc16f6a81 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 Sep 2020 23:21:43 -0700 Subject: [PATCH 05/23] Address comments. --- .../sql/catalyst/analysis/ResolveUnion.scala | 13 ++++---- .../apache/spark/sql/types/StructType.scala | 19 ++++++++---- .../spark/sql/types/StructTypeSuite.scala | 30 ++++++++++++------- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index ee9c30299c18..cd735243bc7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -44,18 +44,18 @@ object ResolveUnion extends Rule[LogicalPlan] { val resolver = SQLConf.get.resolver val missingFields = StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) - if (missingFields.length == 0) { + if (missingFields.isEmpty) { None } else { - Some(addFieldsInto(col, "", missingFields.fields)) + missingFields.map(s => addFieldsInto(col, "", s.fields)) } } /** - * Adds missing fields recursively into given `col` expression. The missing fields are given - * in `fields`. For example, given `col` as "a struct, b int", and `fields` is - * "a struct, c string". This method will add a nested `a.c` field and a top-level - * `c` field to `col` and fill null values for them. + * Adds missing fields recursively into given `col` expression. The missing fields are given + * in `fields`. For example, given `col` as "a struct, b int", and `fields` is + * "a struct, c string". This method will add a nested `a.c` field and a top-level + * `c` field to `col` and fill null values for them. 
*/ private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = { fields.foldLeft(col) { case (currCol, field) => @@ -134,6 +134,7 @@ object ResolveUnion extends Rule[LogicalPlan] { // Builds a project list for `right` based on `left` output names val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) + // Delegates failure checks to `CheckAnalysis` val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) val rightChild = Project(rightProjectList ++ notFoundAttrs, right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index fa97fe233e9e..4055e2493a39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -646,7 +646,10 @@ object StructType extends AbstractDataType { * Returns a `StructType` that contains missing fields recursively from `source` to `target`. * Note that this doesn't support looking into array type and map type recursively. */ - def findMissingFields(source: StructType, target: StructType, resolver: Resolver): StructType = { + def findMissingFields( + source: StructType, + target: StructType, + resolver: Resolver): Option[StructType] = { def bothStructType(dt1: DataType, dt2: DataType): Boolean = dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType] @@ -660,11 +663,17 @@ object StructType extends AbstractDataType { } else if (bothStructType(found.get.dataType, field.dataType) && !found.get.dataType.sameType(field.dataType)) { // Found a field with same name, but different data type. - newFields += found.get.copy(dataType = - findMissingFields(found.get.dataType.asInstanceOf[StructType], - field.dataType.asInstanceOf[StructType], resolver)) + findMissingFields(found.get.dataType.asInstanceOf[StructType], + field.dataType.asInstanceOf[StructType], resolver).map { missingType => + newFields += found.get.copy(dataType = missingType) + } } } - StructType(newFields) + + if (newFields.isEmpty) { + None + } else { + Some(StructType(newFields)) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 0f6171b31028..c602afb6494a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -113,22 +113,26 @@ class StructTypeSuite extends SparkFunSuite { val source1 = StructType.fromDDL("c1 INT") val missing1 = StructType.fromDDL( "c2 STRUCT>") - assert(StructType.findMissingFields(source1, schema, resolver).sameType(missing1)) + assert(StructType.findMissingFields(source1, schema, resolver) + .map(_.sameType(missing1)).getOrElse(false)) val source2 = StructType.fromDDL("c1 INT, c3 STRING") val missing2 = StructType.fromDDL( "c2 STRUCT>") - assert(StructType.findMissingFields(source2, schema, resolver).sameType(missing2)) + assert(StructType.findMissingFields(source2, schema, resolver) + .map(_.sameType(missing2)).getOrElse(false)) val source3 = StructType.fromDDL("c1 INT, c2 STRUCT") val missing3 = StructType.fromDDL( "c2 STRUCT>") - assert(StructType.findMissingFields(source3, schema, resolver).sameType(missing3)) + assert(StructType.findMissingFields(source3, schema, resolver) + .map(_.sameType(missing3)).getOrElse(false)) val source4 = 
StructType.fromDDL("c1 INT, c2 STRUCT>") val missing4 = StructType.fromDDL( "c2 STRUCT>") - assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4)) + assert(StructType.findMissingFields(source4, schema, resolver) + .map(_.sameType(missing4)).getOrElse(false)) val schemaWithArray = StructType.fromDDL( "c1 INT, c2 ARRAY>") @@ -136,7 +140,9 @@ class StructTypeSuite extends SparkFunSuite { "c1 INT") val missing5 = StructType.fromDDL( "c2 ARRAY>") - assert(StructType.findMissingFields(source5, schemaWithArray, resolver).sameType(missing5)) + assert( + StructType.findMissingFields(source5, schemaWithArray, resolver) + .map(_.sameType(missing5)).getOrElse(false)) val schemaWithMap1 = StructType.fromDDL( "c1 INT, c2 MAP, STRING>, c3 LONG") @@ -144,7 +150,9 @@ class StructTypeSuite extends SparkFunSuite { "c1 INT, c3 LONG") val missing6 = StructType.fromDDL( "c2 MAP, STRING>") - assert(StructType.findMissingFields(source6, schemaWithMap1, resolver).sameType(missing6)) + assert( + StructType.findMissingFields(source6, schemaWithMap1, resolver) + .map(_.sameType(missing6)).getOrElse(false)) val schemaWithMap2 = StructType.fromDDL( "c1 INT, c2 MAP>, c3 STRING") @@ -152,22 +160,24 @@ class StructTypeSuite extends SparkFunSuite { "c1 INT, c3 STRING") val missing7 = StructType.fromDDL( "c2 MAP>") - assert(StructType.findMissingFields(source7, schemaWithMap2, resolver).sameType(missing7)) + assert( + StructType.findMissingFields(source7, schemaWithMap2, resolver) + .map(_.sameType(missing7)).getOrElse(false)) // Unsupported: nested struct in array, map val source8 = StructType.fromDDL( "c1 INT, c2 ARRAY>") // `findMissingFields` doesn't support looking into nested struct in array type. - assert(StructType.findMissingFields(source8, schemaWithArray, resolver).length == 0) + assert(StructType.findMissingFields(source8, schemaWithArray, resolver).isEmpty) val source9 = StructType.fromDDL( "c1 INT, c2 MAP, STRING>, c3 LONG") // `findMissingFields` doesn't support looking into nested struct in map type. - assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).length == 0) + assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).isEmpty) val source10 = StructType.fromDDL( "c1 INT, c2 MAP>, c3 STRING") // `findMissingFields` doesn't support looking into nested struct in map type. - assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).length == 0) + assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).isEmpty) } } From c787f664cb0248370d07ab829070ce889fb02a0c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 12 Sep 2020 16:24:49 -0700 Subject: [PATCH 06/23] Revise comments and add case-sensitive test. 
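
The new test pins down resolver-sensitive lookups: with
spark.sql.caseSensitive=true, `C2` and `c2` are different fields. A minimal
spark-shell sketch of the behavior under test (assumes a build with this
patch; `caseSensitiveResolution` is the resolver used when the flag is on):

    import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
    import org.apache.spark.sql.types.StructType
    val source = StructType.fromDDL("c1 INT, C2 LONG")
    val target = StructType.fromDDL("c1 INT, c2 INT")
    // C2 does not match c2 case-sensitively, so c2 is reported as missing.
    StructType.findMissingFields(source, target, caseSensitiveResolution)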
--- .../sql/catalyst/analysis/ResolveUnion.scala | 16 +++--- .../spark/sql/types/StructTypeSuite.scala | 49 ++++++++++++++++--- .../sql/DataFrameSetOperationsSuite.scala | 4 +- 3 files changed, 53 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index cd735243bc7a..37d624d9ab48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -100,15 +100,20 @@ object ResolveUnion extends Rule[LogicalPlan] { case (source: StructType, target: StructType) if allowMissingCol && !source.sameType(target) => // Having an output with same name, but different struct type. - // We need to add missing fields. + // We need to add missing fields. Note that if there are deeply nested structs such as + // nested struct of array in struct, we don't support to add missing deeply nested field + // like that. For such case, simply use original attribute. addFields(found.get, target).map { added => aliased += found.get Alias(added, found.get.name)() - }.getOrElse(found.get) // Data type doesn't change. We should add fields at other side. + }.getOrElse(found.get) case _ => - // Same struct type, or - // unsupported: different types, array or map types, or - // `allowMissingCol` is disabled. + // We don't need/try to add missing fields if: + // 1. The attributes of left and right side are the same struct type + // 2. The attributes are not struct types. They might be primitive types, or array, map + // types. We don't support adding missing fields of nested structs in array or map + // types now. + // 3. `allowMissingCol` is disabled. 
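+          // In all these cases the right-side attribute is passed through unchanged.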
found.get } } else { @@ -134,7 +139,6 @@ object ResolveUnion extends Rule[LogicalPlan] { // Builds a project list for `right` based on `left` output names val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) - // Delegates failure checks to `CheckAnalysis` val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) val rightChild = Project(rightProjectList ++ notFoundAttrs, right) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index c602afb6494a..ab388a320cde 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.types import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType.fromDDL -class StructTypeSuite extends SparkFunSuite { +class StructTypeSuite extends SparkFunSuite with SQLHelper { private val s = StructType.fromDDL("a INT, b STRING") @@ -114,25 +115,25 @@ class StructTypeSuite extends SparkFunSuite { val missing1 = StructType.fromDDL( "c2 STRUCT>") assert(StructType.findMissingFields(source1, schema, resolver) - .map(_.sameType(missing1)).getOrElse(false)) + .exists(_.sameType(missing1))) val source2 = StructType.fromDDL("c1 INT, c3 STRING") val missing2 = StructType.fromDDL( "c2 STRUCT>") assert(StructType.findMissingFields(source2, schema, resolver) - .map(_.sameType(missing2)).getOrElse(false)) + .exists(_.sameType(missing2))) val source3 = StructType.fromDDL("c1 INT, c2 STRUCT") val missing3 = StructType.fromDDL( "c2 STRUCT>") assert(StructType.findMissingFields(source3, schema, resolver) - .map(_.sameType(missing3)).getOrElse(false)) + .exists(_.sameType(missing3))) val source4 = StructType.fromDDL("c1 INT, c2 STRUCT>") val missing4 = StructType.fromDDL( "c2 STRUCT>") assert(StructType.findMissingFields(source4, schema, resolver) - .map(_.sameType(missing4)).getOrElse(false)) + .exists(_.sameType(missing4))) val schemaWithArray = StructType.fromDDL( "c1 INT, c2 ARRAY>") @@ -142,7 +143,7 @@ class StructTypeSuite extends SparkFunSuite { "c2 ARRAY>") assert( StructType.findMissingFields(source5, schemaWithArray, resolver) - .map(_.sameType(missing5)).getOrElse(false)) + .exists(_.sameType(missing5))) val schemaWithMap1 = StructType.fromDDL( "c1 INT, c2 MAP, STRING>, c3 LONG") @@ -152,7 +153,7 @@ class StructTypeSuite extends SparkFunSuite { "c2 MAP, STRING>") assert( StructType.findMissingFields(source6, schemaWithMap1, resolver) - .map(_.sameType(missing6)).getOrElse(false)) + .exists(_.sameType(missing6))) val schemaWithMap2 = StructType.fromDDL( "c1 INT, c2 MAP>, c3 STRING") @@ -162,7 +163,7 @@ class StructTypeSuite extends SparkFunSuite { "c2 MAP>") assert( StructType.findMissingFields(source7, schemaWithMap2, resolver) - .map(_.sameType(missing7)).getOrElse(false)) + .exists(_.sameType(missing7))) // Unsupported: nested struct in array, map val source8 = StructType.fromDDL( @@ -180,4 +181,36 @@ class StructTypeSuite extends SparkFunSuite { // `findMissingFields` doesn't support looking into nested struct in map type. 
assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).isEmpty) } + + test("find missing (nested) fields: case sensitive cases") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val schema = StructType.fromDDL( + "c1 INT, c2 STRUCT>") + val resolver = SQLConf.get.resolver + + val source1 = StructType.fromDDL("c1 INT, C2 LONG") + val missing1 = StructType.fromDDL( + "c2 STRUCT>") + assert(StructType.findMissingFields(source1, schema, resolver) + .exists(_.sameType(missing1))) + + val source2 = StructType.fromDDL("c2 LONG") + val missing2 = StructType.fromDDL( + "c1 INT") + assert(StructType.findMissingFields(source2, schema, resolver) + .exists(_.sameType(missing2))) + + val source3 = StructType.fromDDL("c1 INT, c2 STRUCT>") + val missing3 = StructType.fromDDL( + "c2 STRUCT>") + assert(StructType.findMissingFields(source3, schema, resolver) + .exists(_.sameType(missing3))) + + val source4 = StructType.fromDDL("c1 INT, c2 STRUCT>") + val missing4 = StructType.fromDDL( + "c2 STRUCT>") + assert(StructType.findMissingFields(source4, schema, resolver) + .exists(_.sameType(missing4))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 72b387b7fdff..9a495533de62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -537,7 +537,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } } - test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - 1") { + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - simple") { val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx") val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx") val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2)) @@ -569,7 +569,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT") } - test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - 2") { + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - nested") { val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a") From 3ea24af9d8cff1afc8f86d2998ab68ad1412341d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 12 Sep 2020 17:35:40 -0700 Subject: [PATCH 07/23] Add another case-sensitive test. 
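
Besides the test, this fixes CombineWithFields dropping the sort flag when
collapsing nested WithFields: the rewritten expression previously fell back
to the default sortOutputColumns = false. The rule now requires both flags
to agree and propagates them:

    WithFields(WithFields(struct, n1, v1, sort), n2, v2, sort)
      => WithFields(struct, n1 ++ n2, v1 ++ v2, sort)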
--- .../sql/catalyst/optimizer/WithFields.scala | 2 +- .../sql/DataFrameSetOperationsSuite.scala | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala index 44572a9c46d9..23d23afc11ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -29,7 +29,7 @@ object CombineWithFields extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case WithFields(WithFields(struct, names1, valExprs1, sort1), names2, valExprs2, sort2) if sort1 == sort2 => - WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) + WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2, sort1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 9a495533de62..a7dce9cc00af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -598,9 +598,45 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") } + + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" + + "- case-sensitive cases") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") + val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a") + + var unionDf = df1.unionByName(df2, true) + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: + Row(1, Row(1, 2, Row(2, null, 3L, null))) :: Nil) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>") + + unionDf = df2.unionByName(df1, true) + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, null, 3L, null))) :: + Row(0, Row(0, 1, Row(null, 1, null, "2"))) :: Nil) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT, `c`: STRING>>") + + val df3 = Seq((2, UnionClass1b(2, 3L, UnionClass3(4, 5L)))).toDF("id", "a") + unionDf = df2.unionByName(df3, true) + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, null, 3L))) :: + Row(2, Row(2, 3, Row(null, 4, 5L))) :: Nil) + assert(unionDf.schema.toDDL == + "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT>>") + } + } } case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) case class UnionClass1b(a: Int, b: Long, nested: UnionClass3) +case class UnionClass1c(a: Int, b: Long, nested: UnionClass4) + case class UnionClass2(a: Int, c: String) case class UnionClass3(a: Int, b: Long) +case class UnionClass4(A: Int, b: Long) From ae14447e7f9bf1c99e6a6e471636d169a4d8340b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 12 Sep 2020 22:55:54 -0700 Subject: [PATCH 08/23] Address comments. 
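
Review cleanups only: bind `found.get` to `foundAttr` once in
compareAndAddFields, factor the repeated expected schema DDL into a val in
the nested-union test, and compact the StructTypeSuite DDL literals.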
--- .../sql/catalyst/analysis/ResolveUnion.scala | 11 ++-- .../spark/sql/types/StructTypeSuite.scala | 64 +++++++------------ .../sql/DataFrameSetOperationsSuite.scala | 17 ++--- 3 files changed, 37 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 37d624d9ab48..95d2e8d09b12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -95,7 +95,8 @@ object ResolveUnion extends Rule[LogicalPlan] { val rightProjectList = leftOutputAttrs.map { lattr => val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) } if (found.isDefined) { - val foundDt = found.get.dataType + val foundAttr = found.get + val foundDt = foundAttr.dataType (foundDt, lattr.dataType) match { case (source: StructType, target: StructType) if allowMissingCol && !source.sameType(target) => @@ -103,10 +104,10 @@ object ResolveUnion extends Rule[LogicalPlan] { // We need to add missing fields. Note that if there are deeply nested structs such as // nested struct of array in struct, we don't support to add missing deeply nested field // like that. For such case, simply use original attribute. - addFields(found.get, target).map { added => - aliased += found.get - Alias(added, found.get.name)() - }.getOrElse(found.get) + addFields(foundAttr, target).map { added => + aliased += foundAttr + Alias(added, foundAttr.name)() + }.getOrElse(foundAttr) case _ => // We don't need/try to add missing fields if: // 1. The attributes of left and right side are the same struct type diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index ab388a320cde..645e65f06508 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -107,108 +107,92 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { } test("find missing (nested) fields") { - val schema = StructType.fromDDL( - "c1 INT, c2 STRUCT>") + val schema = StructType.fromDDL("c1 INT, c2 STRUCT>") val resolver = SQLConf.get.resolver val source1 = StructType.fromDDL("c1 INT") - val missing1 = StructType.fromDDL( - "c2 STRUCT>") + val missing1 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source1, schema, resolver) .exists(_.sameType(missing1))) val source2 = StructType.fromDDL("c1 INT, c3 STRING") - val missing2 = StructType.fromDDL( - "c2 STRUCT>") + val missing2 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source2, schema, resolver) .exists(_.sameType(missing2))) val source3 = StructType.fromDDL("c1 INT, c2 STRUCT") - val missing3 = StructType.fromDDL( - "c2 STRUCT>") + val missing3 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source3, schema, resolver) .exists(_.sameType(missing3))) val source4 = StructType.fromDDL("c1 INT, c2 STRUCT>") - val missing4 = StructType.fromDDL( - "c2 STRUCT>") + val missing4 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source4, schema, resolver) .exists(_.sameType(missing4))) + } + + test("find missing (nested) fields: array and map") { + val resolver = SQLConf.get.resolver - val schemaWithArray = 
StructType.fromDDL( - "c1 INT, c2 ARRAY>") - val source5 = StructType.fromDDL( - "c1 INT") - val missing5 = StructType.fromDDL( - "c2 ARRAY>") + val schemaWithArray = StructType.fromDDL("c1 INT, c2 ARRAY>") + val source5 = StructType.fromDDL("c1 INT") + val missing5 = StructType.fromDDL("c2 ARRAY>") assert( StructType.findMissingFields(source5, schemaWithArray, resolver) .exists(_.sameType(missing5))) val schemaWithMap1 = StructType.fromDDL( "c1 INT, c2 MAP, STRING>, c3 LONG") - val source6 = StructType.fromDDL( - "c1 INT, c3 LONG") - val missing6 = StructType.fromDDL( - "c2 MAP, STRING>") + val source6 = StructType.fromDDL("c1 INT, c3 LONG") + val missing6 = StructType.fromDDL("c2 MAP, STRING>") assert( StructType.findMissingFields(source6, schemaWithMap1, resolver) .exists(_.sameType(missing6))) val schemaWithMap2 = StructType.fromDDL( "c1 INT, c2 MAP>, c3 STRING") - val source7 = StructType.fromDDL( - "c1 INT, c3 STRING") - val missing7 = StructType.fromDDL( - "c2 MAP>") + val source7 = StructType.fromDDL("c1 INT, c3 STRING") + val missing7 = StructType.fromDDL("c2 MAP>") assert( StructType.findMissingFields(source7, schemaWithMap2, resolver) .exists(_.sameType(missing7))) // Unsupported: nested struct in array, map - val source8 = StructType.fromDDL( - "c1 INT, c2 ARRAY>") + val source8 = StructType.fromDDL("c1 INT, c2 ARRAY>") // `findMissingFields` doesn't support looking into nested struct in array type. assert(StructType.findMissingFields(source8, schemaWithArray, resolver).isEmpty) - val source9 = StructType.fromDDL( - "c1 INT, c2 MAP, STRING>, c3 LONG") + val source9 = StructType.fromDDL("c1 INT, c2 MAP, STRING>, c3 LONG") // `findMissingFields` doesn't support looking into nested struct in map type. assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).isEmpty) - val source10 = StructType.fromDDL( - "c1 INT, c2 MAP>, c3 STRING") + val source10 = StructType.fromDDL("c1 INT, c2 MAP>, c3 STRING") // `findMissingFields` doesn't support looking into nested struct in map type. 
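+    // Hence nothing is reported here even though the nested struct fields differ.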
assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).isEmpty) } test("find missing (nested) fields: case sensitive cases") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - val schema = StructType.fromDDL( - "c1 INT, c2 STRUCT>") + val schema = StructType.fromDDL("c1 INT, c2 STRUCT>") val resolver = SQLConf.get.resolver val source1 = StructType.fromDDL("c1 INT, C2 LONG") - val missing1 = StructType.fromDDL( - "c2 STRUCT>") + val missing1 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source1, schema, resolver) .exists(_.sameType(missing1))) val source2 = StructType.fromDDL("c2 LONG") - val missing2 = StructType.fromDDL( - "c1 INT") + val missing2 = StructType.fromDDL("c1 INT") assert(StructType.findMissingFields(source2, schema, resolver) .exists(_.sameType(missing2))) val source3 = StructType.fromDDL("c1 INT, c2 STRUCT>") - val missing3 = StructType.fromDDL( - "c2 STRUCT>") + val missing3 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source3, schema, resolver) .exists(_.sameType(missing3))) val source4 = StructType.fromDDL("c1 INT, c2 STRUCT>") - val missing4 = StructType.fromDDL( - "c2 STRUCT>") + val missing4 = StructType.fromDDL("c2 STRUCT>") assert(StructType.findMissingFields(source4, schema, resolver) .exists(_.sameType(missing4))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index a7dce9cc00af..72e04c75bb3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -573,34 +573,31 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a") + val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>" + var unionDf = df1.unionByName(df2, true) checkAnswer(unionDf, Row(0, Row(0, 1, Row(1, null, "2"))) :: Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil) - assert(unionDf.schema.toDDL == - "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") + assert(unionDf.schema.toDDL == expectedSchema) unionDf = df2.unionByName(df1, true) checkAnswer(unionDf, Row(1, Row(1, 2, Row(2, 3L, null))) :: Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil) - assert(unionDf.schema.toDDL == - "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") + assert(unionDf.schema.toDDL == expectedSchema) val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a") unionDf = df1.unionByName(df3, true) checkAnswer(unionDf, Row(0, Row(0, 1, Row(1, null, "2"))) :: Row(2, Row(2, 3, null)) :: Nil) - assert(unionDf.schema.toDDL == - "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>") + assert(unionDf.schema.toDDL == expectedSchema) } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" + - "- case-sensitive cases") { + " - case-sensitive cases") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a") From 72800e6f755c2ae25bbe4273e7c06bed484795c1 
Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 13 Sep 2020 22:00:36 -0700 Subject: [PATCH 09/23] Enhance comments. --- .../sql/catalyst/analysis/ResolveUnion.scala | 25 +++++++++++++------ .../expressions/complexTypeCreator.scala | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 95d2e8d09b12..d8f209389e76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -34,9 +34,10 @@ object ResolveUnion extends Rule[LogicalPlan] { /** * Adds missing fields recursively into given `col` expression, based on the target `StructType`. - * For example, given `col` as "a struct, b int" and `target` as - * "a struct, b int, c string", this method should add `a.c` and `c` to - * `col` expression. + * This is called by `compareAndAddFields` when we find two struct columns with the same name but + * different nested fields. It finds the nested fields that `col` is missing relative to the + * `target` struct and adds them. Currently we don't support finding missing nested fields of + * structs nested in arrays or maps. */ private def addFields(col: NamedExpression, target: StructType): Option[Expression] = { assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") @@ -53,9 +54,13 @@ object ResolveUnion extends Rule[LogicalPlan] { /** * Adds missing fields recursively into given `col` expression. The missing fields are given - * in `fields`. For example, given `col` as "a struct, b int", and `fields` is - * "a struct, c string". This method will add a nested `a.c` field and a top-level - * `c` field to `col` and fill null values for them. + * in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and `fields` is + * "z struct<w:int>, w string". This method will add a nested `z.w` field and a top-level + * `w` field to `col` and fill null values for them. Note that because we might also add missing + * fields on the other side of the Union, we must make sure the corresponding attributes on both + * sides have the same field order in their structs, so when adding missing fields we sort the + * fields by name. The data type of the returned expression is therefore + * "w string, x int, z struct<w:int, y:int, z:int>". */ private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = { fields.foldLeft(col) { case (currCol, field) => @@ -82,6 +87,12 @@ object ResolveUnion extends Rule[LogicalPlan] { } } + /** + * This method compares the outputs of the right plan with those of the left plan. If a struct + * attribute on the right side has the same name as a struct attribute on the left side but a + * different data type, i.e., the right struct attribute is missing some (nested) fields, this + * method tries to add the missing (nested) fields to the right attribute and fills them with + * null values. + */ private def compareAndAddFields( left: LogicalPlan, right: LogicalPlan, @@ -115,7 +126,7 @@ object ResolveUnion extends Rule[LogicalPlan] { // types. We don't support adding missing fields of nested structs in array or map // types now. // 3. `allowMissingCol` is disabled.
- found.get + foundAttr } } else { if (allowMissingCol) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d11bb195450a..57aa9a2786e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -609,7 +609,7 @@ object WithFields { * Adds/replaces field in `StructType` into `col` expression by name. */ def apply(col: Expression, fieldName: String, expr: Expression): Expression = { - WithFields(col, fieldName, expr, false) + WithFields(col, fieldName, expr, sortOutputColumns = false) } def apply( From 337cea7ebccbce3b78ee59818311bd06b2e4a072 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 13 Sep 2020 22:22:54 -0700 Subject: [PATCH 10/23] Add SQL config. --- .../sql/catalyst/analysis/ResolveUnion.scala | 4 +- .../apache/spark/sql/internal/SQLConf.scala | 13 ++ .../scala/org/apache/spark/sql/Dataset.scala | 3 +- .../sql/DataFrameSetOperationsSuite.scala | 167 ++++++++++-------- 4 files changed, 113 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index d8f209389e76..d63768fb2b47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -103,6 +103,8 @@ object ResolveUnion extends Rule[LogicalPlan] { val aliased = mutable.ArrayBuffer.empty[Attribute] + val supportStruct = SQLConf.get.unionByNameStructSupportEnabled + val rightProjectList = leftOutputAttrs.map { lattr => val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) } if (found.isDefined) { @@ -110,7 +112,7 @@ object ResolveUnion extends Rule[LogicalPlan] { val foundDt = foundAttr.dataType (foundDt, lattr.dataType) match { case (source: StructType, target: StructType) - if allowMissingCol && !source.sameType(target) => + if supportStruct && allowMissingCol && !source.sameType(target) => // Having an output with same name, but different struct type. // We need to add missing fields. Note that if there are deeply nested structs such as // nested struct of array in struct, we don't support to add missing deeply nested field diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 57c5c39bdeb7..b71cc24491c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2712,6 +2712,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val UNION_BYNAME_STRUCT_SUPPORT_ENABLED = + buildConf("spark.sql.unionByName.structSupport.enabled") + .doc("When true, the `allowMissingColumns` feature of `Dataset.unionByName` supports " + + "nested column in struct types. Missing nested columns of struct columns with same " + + "name will also be filled with null values. This currently does not support nested " + + "columns in array and map types.") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + /** * Holds information about keys that have been deprecated. 
* @@ -3020,6 +3030,9 @@ class SQLConf extends Serializable with Logging { LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)) } + def unionByNameStructSupportEnabled: Boolean = + getConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED) + def broadcastHashJoinOutputPartitioningExpandLimit: Int = getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4b8d64fb6cf7..79995a75f416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2067,7 +2067,8 @@ class Dataset[T] private[sql]( * // +----+----+----+----+ * }}} * - * Note that `allowMissingColumns` supports nested column in struct types. Missing nested columns + * Note that `allowMissingColumns` supports nested column in struct types, if the config + * `spark.sql.unionByName.structSupport.enabled` is enabled. Missing nested columns * of struct columns with same name will also be filled with null values. This currently does not * support nested columns in array and map types. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 72e04c75bb3b..2ead94a3f827 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -507,98 +507,107 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } test("SPARK-29358: Make unionByName optionally fill missing columns with nulls") { - var df1 = Seq(1, 2, 3).toDF("a") - var df2 = Seq(3, 1, 2).toDF("b") - val df3 = Seq(2, 3, 1).toDF("c") - val unionDf = df1.unionByName(df2.unionByName(df3, true), true) - checkAnswer(unionDf, - Row(1, null, null) :: Row(2, null, null) :: Row(3, null, null) :: // df1 - Row(null, 3, null) :: Row(null, 1, null) :: Row(null, 2, null) :: // df2 - Row(null, null, 2) :: Row(null, null, 3) :: Row(null, null, 1) :: Nil // df3 - ) - - df1 = Seq((1, 2)).toDF("a", "c") - df2 = Seq((3, 4, 5)).toDF("a", "b", "c") - checkAnswer(df1.unionByName(df2, true), - Row(1, 2, null) :: Row(3, 5, 4) :: Nil) - checkAnswer(df2.unionByName(df1, true), - Row(3, 4, 5) :: Row(1, null, 2) :: Nil) - - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - df2 = Seq((3, 4, 5)).toDF("a", "B", "C") - val union1 = df1.unionByName(df2, true) - val union2 = df2.unionByName(df1, true) - - checkAnswer(union1, Row(1, 2, null, null) :: Row(3, null, 4, 5) :: Nil) - checkAnswer(union2, Row(3, 4, 5, null) :: Row(1, null, null, 2) :: Nil) - - assert(union1.schema.fieldNames === Array("a", "c", "B", "C")) - assert(union2.schema.fieldNames === Array("a", "B", "C", "c")) + Seq("true", "false").foreach { config => + withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> config) { + var df1 = Seq(1, 2, 3).toDF("a") + var df2 = Seq(3, 1, 2).toDF("b") + val df3 = Seq(2, 3, 1).toDF("c") + val unionDf = df1.unionByName(df2.unionByName(df3, true), true) + checkAnswer(unionDf, + Row(1, null, null) :: Row(2, null, null) :: Row(3, null, null) :: // df1 + Row(null, 3, null) :: Row(null, 1, null) :: Row(null, 2, null) :: // df2 + Row(null, null, 2) :: Row(null, null, 3) :: Row(null, null, 1) :: Nil // df3 + ) + + df1 = Seq((1, 2)).toDF("a", "c") + df2 = Seq((3, 4, 5)).toDF("a", "b", "c") + 
checkAnswer(df1.unionByName(df2, true), + Row(1, 2, null) :: Row(3, 5, 4) :: Nil) + checkAnswer(df2.unionByName(df1, true), + Row(3, 4, 5) :: Row(1, null, 2) :: Nil) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + df2 = Seq((3, 4, 5)).toDF("a", "B", "C") + val union1 = df1.unionByName(df2, true) + val union2 = df2.unionByName(df1, true) + + checkAnswer(union1, Row(1, 2, null, null) :: Row(3, null, 4, 5) :: Nil) + checkAnswer(union2, Row(3, 4, 5, null) :: Row(1, null, null, 2) :: Nil) + + assert(union1.schema.fieldNames === Array("a", "c", "B", "C")) + assert(union2.schema.fieldNames === Array("a", "B", "C", "c")) + } + } } } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - simple") { - val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx") - val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx") - val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2)) - .toDF("a", "idx") + withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") { + val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx") + val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx") + val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2)) + .toDF("a", "idx") - var unionDf = df1.unionByName(df2, true) + var unionDf = df1.unionByName(df2, true) - checkAnswer(unionDf, - Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) :: // df1 - Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil // df2 - ) + checkAnswer(unionDf, + Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) :: + Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil + ) - assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT") + assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT") - unionDf = df1.unionByName(df2, true).unionByName(df3, true) + unionDf = df1.unionByName(df2, true).unionByName(df3, true) - checkAnswer(unionDf, - Row(Row(1, 2, 3, null), 0) :: - Row(Row(2, 3, 4, null), 1) :: - Row(Row(3, 4, 5, null), 2) :: // df1 - Row(Row(3, 4, null, null), 0) :: - Row(Row(1, 2, null, null), 1) :: - Row(Row(2, 3, null, null), 2) :: // df2 - Row(Row(100, 101, 102, 103), 0) :: - Row(Row(110, 111, 112, 113), 1) :: - Row(Row(120, 121, 122, 123), 2) :: Nil // df3 - ) - assert(unionDf.schema.toDDL == - "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT") + checkAnswer(unionDf, + Row(Row(1, 2, 3, null), 0) :: + Row(Row(2, 3, 4, null), 1) :: + Row(Row(3, 4, 5, null), 2) :: // df1 + Row(Row(3, 4, null, null), 0) :: + Row(Row(1, 2, null, null), 1) :: + Row(Row(2, 3, null, null), 2) :: // df2 + Row(Row(100, 101, 102, 103), 0) :: + Row(Row(110, 111, 112, 113), 1) :: + Row(Row(120, 121, 122, 123), 2) :: Nil // df3 + ) + assert(unionDf.schema.toDDL == + "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT") + } } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - nested") { - val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") - val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a") + withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") { + val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") + val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", 
"a") - val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + - "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>" + val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " + + "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>" - var unionDf = df1.unionByName(df2, true) - checkAnswer(unionDf, - Row(0, Row(0, 1, Row(1, null, "2"))) :: - Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) + var unionDf = df1.unionByName(df2, true) + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(1, null, "2"))) :: + Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil) + assert(unionDf.schema.toDDL == expectedSchema) - unionDf = df2.unionByName(df1, true) - checkAnswer(unionDf, - Row(1, Row(1, 2, Row(2, 3L, null))) :: - Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) + unionDf = df2.unionByName(df1, true) + checkAnswer(unionDf, + Row(1, Row(1, 2, Row(2, 3L, null))) :: + Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil) + assert(unionDf.schema.toDDL == expectedSchema) - val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a") - unionDf = df1.unionByName(df3, true) - checkAnswer(unionDf, - Row(0, Row(0, 1, Row(1, null, "2"))) :: - Row(2, Row(2, 3, null)) :: Nil) - assert(unionDf.schema.toDDL == expectedSchema) + val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a") + unionDf = df1.unionByName(df3, true) + checkAnswer(unionDf, + Row(0, Row(0, 1, Row(1, null, "2"))) :: + Row(2, Row(2, 3, null)) :: Nil) + assert(unionDf.schema.toDDL == expectedSchema) + } } test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" + " - case-sensitive cases") { - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true", + SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") { val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a") val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a") @@ -628,6 +637,20 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { "`nested`: STRUCT<`A`: INT, `a`: INT, `b`: BIGINT>>") } } + + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - disable") { + withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "false") { + val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx") + val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx") + + val err = intercept[AnalysisException] { + df1.unionByName(df2, true).collect() + } + assert(err.getMessage.contains("Union can only be performed on tables with the compatible " + + "column types. struct<_1:int,_2:int> <> struct<_1:int,_2:int,_3:int> at the first column " + + "of the second table")) + } + } } case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) From 7b0d65d74e305492af5a5b2551cd3628d9be068c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Sep 2020 17:40:06 -0700 Subject: [PATCH 11/23] Move sorting column out of WithFields. 
--- .../sql/catalyst/analysis/ResolveUnion.scala | 50 +++++++++++++++---- .../expressions/complexTypeCreator.scala | 33 +++--------- .../sql/catalyst/optimizer/ComplexTypes.scala | 2 +- .../sql/catalyst/optimizer/WithFields.scala | 5 +- 4 files changed, 51 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index d63768fb2b47..51641101fc1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -20,18 +20,47 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, WithFields} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, KnownNotNull, Literal, NamedExpression, WithFields} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.unsafe.types.UTF8String /** * Resolves different children of Union to a common set of columns. */ object ResolveUnion extends Rule[LogicalPlan] { + private def sortStructFields(fieldExprs: Seq[Expression]): Seq[Expression] = { + fieldExprs.grouped(2).map { e => + Seq(e.head, e.last) + }.toSeq.sortBy { pair => + assert(pair(0).isInstanceOf[Literal]) + pair(0).eval().asInstanceOf[UTF8String].toString + }.flatten + } + + /** + * This helper method sorts fields in a `WithFields` expression by field name. + */ + private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp { + case w: WithFields if w.resolved => + w.evalExpr match { + case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) => + val sorted = sortStructFields(fieldExprs) + val newStruct = CreateNamedStruct(sorted) + i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct) + case CreateNamedStruct(fieldExprs) => + val sorted = sortStructFields(fieldExprs) + val newStruct = CreateNamedStruct(sorted) + newStruct + case other => + throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other") + } + } + /** * Adds missing fields recursively into given `col` expression, based on the target `StructType`. * This is called by `compareAndAddFields` when we find two struct columns with same name but @@ -48,7 +77,14 @@ object ResolveUnion extends Rule[LogicalPlan] { if (missingFields.isEmpty) { None } else { - missingFields.map(s => addFieldsInto(col, "", s.fields)) + missingFields.map { s => + val struct = addFieldsInto(col, "", s.fields) + // We need to sort columns in result, because we might add another column in other side. + // E.g., we want to union two structs "a int, b long" and "a int, c string". + // If we don't sort, we will have "a int, b long, c string" and + // "a int, c string, b long", which are not compatible. 
+ sortStructFieldsInWithFields(struct) + } } } @@ -71,18 +107,12 @@ object ResolveUnion extends Rule[LogicalPlan] { .find(f => resolver(f.name, field.name)) if (colField.isEmpty) { // The whole struct is missing. Add a null. - WithFields(currCol, s"$base${field.name}", Literal(null, st), - sortOutputColumns = true) + WithFields(currCol, s"$base${field.name}", Literal(null, st)) } else { addFieldsInto(currCol, s"$base${field.name}.", st.fields) } case dt => - // We need to sort columns in result, because we might add another column in other side. - // E.g., we want to union two structs "a int, b long" and "a int, c string". - // If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long", - // which are not compatible. - WithFields(currCol, s"$base${field.name}", Literal(null, dt), - sortOutputColumns = true) + WithFields(currCol, s"$base${field.name}", Literal(null, dt)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 57aa9a2786e3..e9a0b9c6c142 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -547,8 +547,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E case class WithFields( structExpr: Expression, names: Seq[String], - valExprs: Seq[Expression], - sortOutputColumns: Boolean = false) extends Unevaluable { + valExprs: Seq[Expression]) extends Unevaluable { assert(names.length == valExprs.length) @@ -587,15 +586,9 @@ case class WithFields( } else { resultExprs :+ newExpr } - } - - val finalExprs = if (sortOutputColumns) { - newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) } - } else { - newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) } - } + }.flatMap { case (name, expr) => Seq(Literal(name), expr) } - val expr = CreateNamedStruct(finalExprs) + val expr = CreateNamedStruct(newExprs) if (structExpr.nullable) { If(IsNull(structExpr), Literal(null, expr.dataType), expr) } else { @@ -609,31 +602,22 @@ object WithFields { * Adds/replaces field in `StructType` into `col` expression by name. 
*/ def apply(col: Expression, fieldName: String, expr: Expression): Expression = { - WithFields(col, fieldName, expr, sortOutputColumns = false) - } - - def apply( - col: Expression, - fieldName: String, - expr: Expression, - sortOutputColumns: Boolean): Expression = { val nameParts = if (fieldName.isEmpty) { fieldName :: Nil } else { CatalystSqlParser.parseMultipartIdentifier(fieldName) } - withFieldHelper(col, nameParts, Nil, expr, sortOutputColumns) + withFieldHelper(col, nameParts, Nil, expr) } private def withFieldHelper( struct: Expression, namePartsRemaining: Seq[String], namePartsDone: Seq[String], - value: Expression, - sortOutputColumns: Boolean) : WithFields = { + value: Expression) : WithFields = { val name = namePartsRemaining.head if (namePartsRemaining.length == 1) { - WithFields(struct, name :: Nil, value :: Nil, sortOutputColumns) + WithFields(struct, name :: Nil, value :: Nil) } else { val newNamesRemaining = namePartsRemaining.tail val newNamesDone = namePartsDone :+ name @@ -649,9 +633,8 @@ object WithFields { struct = newStruct, namePartsRemaining = newNamesRemaining, namePartsDone = newNamesDone, - value = value, - sortOutputColumns = sortOutputColumns) - WithFields(struct, name :: Nil, newValue :: Nil, sortOutputColumns) + value = value) + WithFields(struct, name :: Nil, newValue :: Nil) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 9cf4465385eb..2aba4bae397c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -39,7 +39,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // Remove redundant field extraction. case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) - case GetStructField(w @ WithFields(struct, names, valExprs, _), ordinal, maybeName) => + case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) => val name = w.dataType(ordinal).name val matches = names.zip(valExprs).filter(_._1 == name) if (matches.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala index 23d23afc11ea..05c90864e4bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -27,9 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule */ object CombineWithFields extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case WithFields(WithFields(struct, names1, valExprs1, sort1), names2, valExprs2, sort2) - if sort1 == sort2 => - WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2, sort1) + case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => + WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) } } From a77481e426a565cfc768cef1aa43489fe1958980 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Sep 2020 18:14:24 -0700 Subject: [PATCH 12/23] Fix edge case. 
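The edge case being fixed: when the right-side struct is not missing any field relative to the left side, `addFields` returned `None` and the attribute was used untouched, even though its field order can still differ from the left side once the left side gains null-filled fields and is sorted. `addFields` now always returns an expression, sorting the existing fields even when there is nothing to add. Schematically, mirroring the test added in this patch:

  // df1: topLevelCol struct<b: string>
  // df2: topLevelCol struct<b: string, a: string>
  val union = df1.unionByName(df2, allowMissingColumns = true)
  // The left side gains a null-filled `a` and is sorted to struct<a, b>;
  // the right side is missing nothing but must be sorted the same way:
  //   `topLevelCol` STRUCT<`a`: STRING, `b`: STRING>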
--- .../sql/catalyst/analysis/ResolveUnion.scala | 33 ++++++++++++++----- .../sql/DataFrameSetOperationsSuite.scala | 25 ++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 51641101fc1c..ca80fae5e199 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -33,6 +33,24 @@ import org.apache.spark.unsafe.types.UTF8String * Resolves different children of Union to a common set of columns. */ object ResolveUnion extends Rule[LogicalPlan] { + /** + * This method sorts columns in a struct expression based on column names. + */ + private def sortStructFields(expr: Expression): Expression = { + assert(expr.dataType.isInstanceOf[StructType]) + + val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { + case (name, i) => (name, GetStructField(KnownNotNull(expr), i).asInstanceOf[Expression]) + }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2)) + + val newExpr = CreateNamedStruct(existingExprs) + if (expr.nullable) { + If(IsNull(expr), Literal(null, newExpr.dataType), newExpr) + } else { + newExpr + } + } + private def sortStructFields(fieldExprs: Seq[Expression]): Seq[Expression] = { fieldExprs.grouped(2).map { e => Seq(e.head, e.last) @@ -68,14 +86,14 @@ object ResolveUnion extends Rule[LogicalPlan] { * `target` struct and add these missing nested fields. Currently we don't support finding out * missing nested fields of struct nested in array or struct nested in map. */ - private def addFields(col: NamedExpression, target: StructType): Option[Expression] = { + private def addFields(col: NamedExpression, target: StructType): Expression = { assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") val resolver = SQLConf.get.resolver val missingFields = StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) if (missingFields.isEmpty) { - None + sortStructFields(col) } else { missingFields.map { s => val struct = addFieldsInto(col, "", s.fields) @@ -84,7 +102,7 @@ object ResolveUnion extends Rule[LogicalPlan] { // If we don't sort, we will have "a int, b long, c string" and // "a int, c string, b long", which are not compatible. sortStructFieldsInWithFields(struct) - } + }.get } } @@ -146,11 +164,10 @@ object ResolveUnion extends Rule[LogicalPlan] { // Having an output with same name, but different struct type. // We need to add missing fields. Note that if there are deeply nested structs such as // nested struct of array in struct, we don't support to add missing deeply nested field - // like that. For such case, simply use original attribute. - addFields(foundAttr, target).map { added => - aliased += foundAttr - Alias(added, foundAttr.name)() - }.getOrElse(foundAttr) + // like that. We will sort columns in the struct expression to make sure two sides of + // union have consistent schema. + aliased += foundAttr + Alias(addFields(foundAttr, target), foundAttr.name)() case _ => // We don't need/try to add missing fields if: // 1. 
The attributes of left and right side are the same struct type diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 2ead94a3f827..68f1c44488f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -638,6 +638,31 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } } + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - edge case") { + withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") { + val nestedStructType1 = StructType(Seq( + StructField("b", StringType))) + val nestedStructValues1 = Row("b") + + val nestedStructType2 = StructType(Seq( + StructField("b", StringType), + StructField("a", StringType))) + val nestedStructValues2 = Row("b", "a") + + val df1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues1) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType1)))) + + val df2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(nestedStructValues2) :: Nil), + StructType(Seq(StructField("topLevelCol", nestedStructType2)))) + + val union = df1.unionByName(df2, allowMissingColumns = true) + checkAnswer(union, Row(Row(null, "b")) :: Row(Row("a", "b")) :: Nil) + assert(union.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRING, `b`: STRING>") + } + } + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - disable") { withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "false") { val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx") From b4270f4f8879f3225399c93456beb30e2a2c78e9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Sep 2020 18:16:56 -0700 Subject: [PATCH 13/23] Move comment around. --- .../spark/sql/catalyst/analysis/ResolveUnion.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index ca80fae5e199..d67866b7aac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -92,15 +92,16 @@ object ResolveUnion extends Rule[LogicalPlan] { val resolver = SQLConf.get.resolver val missingFields = StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) + + // We need to sort columns in result, because we might add another column in other side. + // E.g., we want to union two structs "a int, b long" and "a int, c string". + // If we don't sort, we will have "a int, b long, c string" and + // "a int, c string, b long", which are not compatible. if (missingFields.isEmpty) { sortStructFields(col) } else { missingFields.map { s => val struct = addFieldsInto(col, "", s.fields) - // We need to sort columns in result, because we might add another column in other side. - // E.g., we want to union two structs "a int, b long" and "a int, c string". - // If we don't sort, we will have "a int, b long, c string" and - // "a int, c string, b long", which are not compatible. 
sortStructFieldsInWithFields(struct) }.get } From 61ff46f675c1e49a7f562d2529f6982f49476579 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 16 Sep 2020 15:59:56 -0700 Subject: [PATCH 14/23] Address comments. --- .../sql/catalyst/analysis/ResolveUnion.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index d67866b7aac9..3765ad277f5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, KnownNotNull, Literal, NamedExpression, WithFields} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule @@ -37,10 +37,8 @@ object ResolveUnion extends Rule[LogicalPlan] { * This method sorts columns in a struct expression based on column names. */ private def sortStructFields(expr: Expression): Expression = { - assert(expr.dataType.isInstanceOf[StructType]) - val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { - case (name, i) => (name, GetStructField(KnownNotNull(expr), i).asInstanceOf[Expression]) + case (name, i) => (name, GetStructField(KnownNotNull(expr), i)) }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2)) val newExpr = CreateNamedStruct(existingExprs) @@ -51,12 +49,16 @@ object ResolveUnion extends Rule[LogicalPlan] { } } - private def sortStructFields(fieldExprs: Seq[Expression]): Seq[Expression] = { + /** + * Assumes input expressions are field expression of `CreateNamedStruct`. This method + * sorts the expressions based on field names. + */ + private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = { fieldExprs.grouped(2).map { e => Seq(e.head, e.last) }.toSeq.sortBy { pair => - assert(pair(0).isInstanceOf[Literal]) - pair(0).eval().asInstanceOf[UTF8String].toString + assert(pair.head.isInstanceOf[Literal]) + pair.head.eval().asInstanceOf[UTF8String].toString }.flatten } @@ -67,11 +69,11 @@ object ResolveUnion extends Rule[LogicalPlan] { case w: WithFields if w.resolved => w.evalExpr match { case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) => - val sorted = sortStructFields(fieldExprs) + val sorted = sortFieldExprs(fieldExprs) val newStruct = CreateNamedStruct(sorted) i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct) case CreateNamedStruct(fieldExprs) => - val sorted = sortStructFields(fieldExprs) + val sorted = sortFieldExprs(fieldExprs) val newStruct = CreateNamedStruct(sorted) newStruct case other => From 9040c56b86a932d7740186637f1e4c9accbee216 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 18 Sep 2020 23:08:08 -0700 Subject: [PATCH 15/23] Combine WithFields. 
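`addFieldsInto` wraps the column once per missing field, so the result is a chain like `WithFields(WithFields(col, ...), ...)`. Before sorting, that chain is now collapsed bottom-up into a single node, using the same rewrite the `CombineWithFields` optimizer rule applies (inner binder renamed here for readability):

  struct.transformUp {
    case WithFields(WithFields(inner, names1, valExprs1), names2, valExprs2) =>
      WithFields(inner, names1 ++ names2, valExprs1 ++ valExprs2)
  }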
--- .../apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 3765ad277f5e..38b17c2d4121 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -104,7 +104,12 @@ object ResolveUnion extends Rule[LogicalPlan] { } else { missingFields.map { s => val struct = addFieldsInto(col, "", s.fields) - sortStructFieldsInWithFields(struct) + // Combines `WithFields`s to reduce expression tree. + val reducedStruct = struct.transformUp { + case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => + WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) + } + sortStructFieldsInWithFields(reducedStruct) }.get } } From 9b21d9172f26fc07b3173be9ad56ae7cd1487305 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 Sep 2020 12:45:00 -0700 Subject: [PATCH 16/23] Optimize WithFields expression chain. --- .../sql/catalyst/analysis/ResolveUnion.scala | 51 ++++++++++++---- .../expressions/complexTypeExtractors.scala | 4 +- .../sql/DataFrameSetOperationsSuite.scala | 59 +++++++++++++++++++ 3 files changed, 101 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 38b17c2d4121..6e523e4b571d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -34,11 +34,17 @@ import org.apache.spark.unsafe.types.UTF8String */ object ResolveUnion extends Rule[LogicalPlan] { /** - * This method sorts columns in a struct expression based on column names. + * This method sorts recursively columns in a struct expression based on column names. 
*/ private def sortStructFields(expr: Expression): Expression = { val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { - case (name, i) => (name, GetStructField(KnownNotNull(expr), i)) + case (name, i) => + val fieldExpr = GetStructField(KnownNotNull(expr), i) + if (fieldExpr.dataType.isInstanceOf[StructType]) { + (name, sortStructFields(fieldExpr)) + } else { + (name, fieldExpr) + } }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2)) val newExpr = CreateNamedStruct(existingExprs) @@ -81,6 +87,26 @@ object ResolveUnion extends Rule[LogicalPlan] { } } + def simplifyWithFields(expr: Expression): Expression = { + expr.transformUp { + case WithFields(structExpr, names, values) if names.distinct.length != names.length => + val newNames = mutable.ArrayBuffer.empty[String] + val newValues = mutable.ArrayBuffer.empty[Expression] + names.zip(values).reverse.foreach { case (name, value) => + if (!newNames.contains(name)) { + newNames += name + newValues += value + } + } + WithFields(structExpr, names = newNames.reverse, valExprs = newValues.reverse) + case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => + WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) + case g @ GetStructField(WithFields(_, names, values), _, _) + if names.contains(g.extractFieldName) => + names.zip(values).reverse.filter(p => p._1 == g.extractFieldName).head._2 + } + } + /** * Adds missing fields recursively into given `col` expression, based on the target `StructType`. * This is called by `compareAndAddFields` when we find two struct columns with same name but @@ -103,13 +129,11 @@ object ResolveUnion extends Rule[LogicalPlan] { sortStructFields(col) } else { missingFields.map { s => - val struct = addFieldsInto(col, "", s.fields) + val struct = addFieldsInto(col, s.fields) // Combines `WithFields`s to reduce expression tree. - val reducedStruct = struct.transformUp { - case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => - WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) - } - sortStructFieldsInWithFields(reducedStruct) + val reducedStruct = simplifyWithFields(struct) + val sorted = sortStructFieldsInWithFields(reducedStruct) + sorted }.get } } @@ -124,7 +148,9 @@ object ResolveUnion extends Rule[LogicalPlan] { * field names. So the data type of returned expression will be * "w string, x int, z struct". */ - private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = { + private def addFieldsInto( + col: Expression, + fields: Seq[StructField]): Expression = { fields.foldLeft(col) { case (currCol, field) => field.dataType match { case st: StructType => @@ -133,12 +159,13 @@ object ResolveUnion extends Rule[LogicalPlan] { .find(f => resolver(f.name, field.name)) if (colField.isEmpty) { // The whole struct is missing. Add a null. 
- WithFields(currCol, s"$base${field.name}", Literal(null, st)) + WithFields(currCol, field.name, Literal(null, st)) } else { - addFieldsInto(currCol, s"$base${field.name}.", st.fields) + WithFields(currCol, field.name, + addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields)) } case dt => - WithFields(currCol, s"$base${field.name}", Literal(null, dt)) + WithFields(currCol, field.name, Literal(null, dt)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 89ff4facd25a..60afe140960c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -116,8 +116,10 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] s"$child.${name.getOrElse(fieldName)}" } + def extractFieldName: String = name.getOrElse(childSchema(ordinal).name) + override def sql: String = - child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" + child.sql + s".${quoteIdentifier(extractFieldName)}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 68f1c44488f5..41440a4e711d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -676,6 +676,65 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { "of the second table")) } } + + test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr") { + def nestedDf(depth: Int, numColsAtEachDepth: Int): DataFrame = { + val initialNestedStructType = StructType( + (0 to numColsAtEachDepth).map(i => + StructField(s"nested${depth}Col$i", IntegerType, nullable = false)) + ) + val initialNestedValues = Row(0 to numColsAtEachDepth: _*) + + var depthCounter = depth - 1 + var structType = initialNestedStructType + var struct = initialNestedValues + while (depthCounter != 0) { + struct = Row((struct +: (1 to numColsAtEachDepth)): _*) + structType = StructType( + StructField(s"nested${depthCounter}Col0", structType, nullable = false) +: + (1 to numColsAtEachDepth).map(i => + StructField(s"nested${depthCounter}Col$i", IntegerType, nullable = false)) + ) + depthCounter -= 1 + } + + val df: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(struct) :: Nil), + StructType(Seq(StructField("nested0Col0", structType)))) + + df + } + withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") { + val df1 = nestedDf(depth = 10, numColsAtEachDepth = 1) + val df2 = nestedDf(depth = 10, numColsAtEachDepth = 20) + val union = df1.unionByName(df2, allowMissingColumns = true) + // scalastyle:off + val row1 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row( + Row(0, 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, 
null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null), + 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)) + val row2 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row( + Row(0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9), + 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9)) + // scalastyle:on + checkAnswer(union, row1 :: row2 :: Nil) + } + } } case class UnionClass1a(a: Int, b: Long, nested: UnionClass2) From 18298899ee94911531a3f441bff636517ae2e68f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 5 Oct 2020 23:28:39 -0700 Subject: [PATCH 17/23] Address comments. --- .../apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 3 ++- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 5 ++++- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 6e523e4b571d..273b00f9d654 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -83,7 +83,8 @@ object ResolveUnion extends Rule[LogicalPlan] { val newStruct = CreateNamedStruct(sorted) newStruct case other => - throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other") + throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other. 
" + + "Please file a bug report with this error message, stack trace, and the query.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f6c9a72e0703..599035b8380c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2726,7 +2726,10 @@ object SQLConf { .doc("When true, the `allowMissingColumns` feature of `Dataset.unionByName` supports " + "nested column in struct types. Missing nested columns of struct columns with same " + "name will also be filled with null values. This currently does not support nested " + - "columns in array and map types.") + "columns in array and map types. Note that if there is any missing nested columns " + + "to be filled, in order to make consistent schema between two sides of union, the " + + "nested fields of structs will be sorted after merging schema." + ) .version("3.1.0") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b91a66761765..8d296ce54984 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2070,7 +2070,9 @@ class Dataset[T] private[sql]( * Note that `allowMissingColumns` supports nested column in struct types, if the config * `spark.sql.unionByName.structSupport.enabled` is enabled. Missing nested columns * of struct columns with same name will also be filled with null values. This currently does not - * support nested columns in array and map types. + * support nested columns in array and map types. Note that if there is any missing nested columns + * to be filled, in order to make consistent schema between two sides of union, the nested fields + * of structs will be sorted after merging schema. * * @group typedrel * @since 3.1.0 From 8a9522e5c0c30634c8031b37daaff2bd84c83ef7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 6 Oct 2020 13:59:07 -0700 Subject: [PATCH 18/23] Synced up with the change from WithFields to UpdateFields. --- .../sql/catalyst/analysis/ResolveUnion.scala | 35 ++++++++----------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 273b00f9d654..579b70e5cab4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -69,11 +69,11 @@ object ResolveUnion extends Rule[LogicalPlan] { } /** - * This helper method sorts fields in a `WithFields` expression by field name. + * This helper method sorts fields in a `UpdateFields` expression by field name. 
*/ private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp { - case w: WithFields if w.resolved => - w.evalExpr match { + case u: UpdateFields if u.resolved => + u.evalExpr match { case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) => val sorted = sortFieldExprs(fieldExprs) val newStruct = CreateNamedStruct(sorted) @@ -83,27 +83,20 @@ object ResolveUnion extends Rule[LogicalPlan] { val newStruct = CreateNamedStruct(sorted) newStruct case other => - throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other. " + + throw new AnalysisException(s"`UpdateFields` has incorrect eval expression: $other. " + "Please file a bug report with this error message, stack trace, and the query.") } } def simplifyWithFields(expr: Expression): Expression = { expr.transformUp { - case WithFields(structExpr, names, values) if names.distinct.length != names.length => - val newNames = mutable.ArrayBuffer.empty[String] - val newValues = mutable.ArrayBuffer.empty[Expression] - names.zip(values).reverse.foreach { case (name, value) => - if (!newNames.contains(name)) { - newNames += name - newValues += value - } - } - WithFields(structExpr, names = newNames.reverse, valExprs = newValues.reverse) - case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => - WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) - case g @ GetStructField(WithFields(_, names, values), _, _) - if names.contains(g.extractFieldName) => + case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) => + UpdateFields(struct, fieldOps1 ++ fieldOps2) + case g @ GetStructField(u: UpdateFields, _, _) + if u.fieldOps.forall(_.isInstanceOf[WithField]) && + u.fieldOps.map(_.asInstanceOf[WithField].name).contains(g.extractFieldName) => + val names = u.fieldOps.map(_.asInstanceOf[WithField].name) + val values = u.fieldOps.map(_.asInstanceOf[WithField].valExpr) names.zip(values).reverse.filter(p => p._1 == g.extractFieldName).head._2 } } @@ -160,13 +153,13 @@ object ResolveUnion extends Rule[LogicalPlan] { .find(f => resolver(f.name, field.name)) if (colField.isEmpty) { // The whole struct is missing. Add a null. - WithFields(currCol, field.name, Literal(null, st)) + UpdateFields(currCol, field.name, Literal(null, st)) } else { - WithFields(currCol, field.name, + UpdateFields(currCol, field.name, addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields)) } case dt => - WithFields(currCol, field.name, Literal(null, dt)) + UpdateFields(currCol, field.name, Literal(null, dt)) } } } From bb8938f6ee9dce04dd96f5cb66c7e6afa6025b25 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 6 Oct 2020 14:12:23 -0700 Subject: [PATCH 19/23] Remove unnecessary code. 
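The case removed below pruned a `GetStructField` sitting directly on an `UpdateFields` whose field operations were all `WithField`s and included the extracted name, replacing the extraction with the most recent value written for that field. Roughly:

  // GetStructField(UpdateFields(s, WithField("a", v1) :: WithField("a", v2) :: Nil), ...)
  // extracting "a"  ==>  v2  (the last value written for "a")

Only the rewrite that merges adjacent `UpdateFields` nodes is kept in `simplifyWithFields`.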
--- .../apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 579b70e5cab4..30fbfe4445f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -92,12 +92,6 @@ object ResolveUnion extends Rule[LogicalPlan] { expr.transformUp { case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) => UpdateFields(struct, fieldOps1 ++ fieldOps2) - case g @ GetStructField(u: UpdateFields, _, _) - if u.fieldOps.forall(_.isInstanceOf[WithField]) && - u.fieldOps.map(_.asInstanceOf[WithField].name).contains(g.extractFieldName) => - val names = u.fieldOps.map(_.asInstanceOf[WithField].name) - val values = u.fieldOps.map(_.asInstanceOf[WithField].valExpr) - names.zip(values).reverse.filter(p => p._1 == g.extractFieldName).head._2 } } From 9e73928443fbba00dd6f47240299f7cc18c2630b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 14 Oct 2020 12:51:09 -0700 Subject: [PATCH 20/23] For review comments. --- .../sql/catalyst/analysis/ResolveUnion.scala | 12 +- .../apache/spark/sql/internal/SQLConf.scala | 16 -- .../sql/DataFrameSetOperationsSuite.scala | 260 ++++++++---------- 3 files changed, 122 insertions(+), 166 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 30fbfe4445f9..8dce9895da7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String */ object ResolveUnion extends Rule[LogicalPlan] { /** - * This method sorts recursively columns in a struct expression based on column names. + * This method sorts columns recursively in a struct expression based on column names. */ private def sortStructFields(expr: Expression): Expression = { val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { @@ -106,17 +106,17 @@ object ResolveUnion extends Rule[LogicalPlan] { assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") val resolver = SQLConf.get.resolver - val missingFields = + val missingFieldsOpt = StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) // We need to sort columns in result, because we might add another column in other side. // E.g., we want to union two structs "a int, b long" and "a int, c string". // If we don't sort, we will have "a int, b long, c string" and // "a int, c string, b long", which are not compatible. - if (missingFields.isEmpty) { + if (missingFieldsOpt.isEmpty) { sortStructFields(col) } else { - missingFields.map { s => + missingFieldsOpt.map { s => val struct = addFieldsInto(col, s.fields) // Combines `WithFields`s to reduce expression tree. 
val reducedStruct = simplifyWithFields(struct) @@ -174,8 +174,6 @@ object ResolveUnion extends Rule[LogicalPlan] { val aliased = mutable.ArrayBuffer.empty[Attribute] - val supportStruct = SQLConf.get.unionByNameStructSupportEnabled - val rightProjectList = leftOutputAttrs.map { lattr => val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) } if (found.isDefined) { @@ -183,7 +181,7 @@ object ResolveUnion extends Rule[LogicalPlan] { val foundDt = foundAttr.dataType (foundDt, lattr.dataType) match { case (source: StructType, target: StructType) - if supportStruct && allowMissingCol && !source.sameType(target) => + if allowMissingCol && !source.sameType(target) => // Having an output with same name, but different struct type. // We need to add missing fields. Note that if there are deeply nested structs such as // nested struct of array in struct, we don't support to add missing deeply nested field diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 44c6ac33b1ea..18ffc655b217 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2740,19 +2740,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val UNION_BYNAME_STRUCT_SUPPORT_ENABLED = - buildConf("spark.sql.unionByName.structSupport.enabled") - .doc("When true, the `allowMissingColumns` feature of `Dataset.unionByName` supports " + - "nested column in struct types. Missing nested columns of struct columns with same " + - "name will also be filled with null values. This currently does not support nested " + - "columns in array and map types. Note that if there are any missing nested columns " + - "to be filled, the nested fields of structs will be sorted after merging the schemas, " + - "in order to keep the schema consistent between the two sides of the union."
-      )
-      .version("3.1.0")
-      .booleanConf
-      .createWithDefault(true)
-
   val LEGACY_PATH_OPTION_BEHAVIOR =
     buildConf("spark.sql.legacy.pathOptionBehavior.enabled")
       .internal()
@@ -3102,9 +3089,6 @@ class SQLConf extends Serializable with Logging {
     LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY))
   }

-  def unionByNameStructSupportEnabled: Boolean =
-    getConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED)
-
   def broadcastHashJoinOutputPartitioningExpandLimit: Int =
     getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index 41440a4e711d..5f28dc60962b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -507,107 +507,98 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
   }

   test("SPARK-29358: Make unionByName optionally fill missing columns with nulls") {
-    Seq("true", "false").foreach { config =>
-      withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> config) {
-        var df1 = Seq(1, 2, 3).toDF("a")
-        var df2 = Seq(3, 1, 2).toDF("b")
-        val df3 = Seq(2, 3, 1).toDF("c")
-        val unionDf = df1.unionByName(df2.unionByName(df3, true), true)
-        checkAnswer(unionDf,
-          Row(1, null, null) :: Row(2, null, null) :: Row(3, null, null) :: // df1
-            Row(null, 3, null) :: Row(null, 1, null) :: Row(null, 2, null) :: // df2
-            Row(null, null, 2) :: Row(null, null, 3) :: Row(null, null, 1) :: Nil // df3
-        )
+    var df1 = Seq(1, 2, 3).toDF("a")
+    var df2 = Seq(3, 1, 2).toDF("b")
+    val df3 = Seq(2, 3, 1).toDF("c")
+    val unionDf = df1.unionByName(df2.unionByName(df3, true), true)
+    checkAnswer(unionDf,
+      Row(1, null, null) :: Row(2, null, null) :: Row(3, null, null) :: // df1
+        Row(null, 3, null) :: Row(null, 1, null) :: Row(null, 2, null) :: // df2
+        Row(null, null, 2) :: Row(null, null, 3) :: Row(null, null, 1) :: Nil // df3
+    )

-        df1 = Seq((1, 2)).toDF("a", "c")
-        df2 = Seq((3, 4, 5)).toDF("a", "b", "c")
-        checkAnswer(df1.unionByName(df2, true),
-          Row(1, 2, null) :: Row(3, 5, 4) :: Nil)
-        checkAnswer(df2.unionByName(df1, true),
-          Row(3, 4, 5) :: Row(1, null, 2) :: Nil)
+    df1 = Seq((1, 2)).toDF("a", "c")
+    df2 = Seq((3, 4, 5)).toDF("a", "b", "c")
+    checkAnswer(df1.unionByName(df2, true),
+      Row(1, 2, null) :: Row(3, 5, 4) :: Nil)
+    checkAnswer(df2.unionByName(df1, true),
+      Row(3, 4, 5) :: Row(1, null, 2) :: Nil)

-      withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
-        df2 = Seq((3, 4, 5)).toDF("a", "B", "C")
-        val union1 = df1.unionByName(df2, true)
-        val union2 = df2.unionByName(df1, true)
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      df2 = Seq((3, 4, 5)).toDF("a", "B", "C")
+      val union1 = df1.unionByName(df2, true)
+      val union2 = df2.unionByName(df1, true)

-        checkAnswer(union1, Row(1, 2, null, null) :: Row(3, null, 4, 5) :: Nil)
-        checkAnswer(union2, Row(3, 4, 5, null) :: Row(1, null, null, 2) :: Nil)
+      checkAnswer(union1, Row(1, 2, null, null) :: Row(3, null, 4, 5) :: Nil)
+      checkAnswer(union2, Row(3, 4, 5, null) :: Row(1, null, null, 2) :: Nil)

-        assert(union1.schema.fieldNames === Array("a", "c", "B", "C"))
-        assert(union2.schema.fieldNames === Array("a", "B", "C", "c"))
-      }
-    }
+      assert(union1.schema.fieldNames === Array("a", "c", "B", "C"))
+      assert(union2.schema.fieldNames === Array("a", "B", "C", "c"))
     }
   }

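   // The SPARK-32376 tests below exercise null-filling for struct columns: missing
   // nested fields are filled with nulls, and struct fields are sorted by name after
   // merging so that both sides of the union converge on the same schema.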
   test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - simple") {
-    withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") {
-      val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx")
-      val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx")
-      val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2))
-        .toDF("a", "idx")
+    val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx")
+    val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx")
+    val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2))
+      .toDF("a", "idx")

-      var unionDf = df1.unionByName(df2, true)
+    var unionDf = df1.unionByName(df2, true)

-      checkAnswer(unionDf,
-        Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) ::
-          Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil
-      )
+    checkAnswer(unionDf,
+      Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) ::
+        Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil
+    )

-      assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT")
+    assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT")

-      unionDf = df1.unionByName(df2, true).unionByName(df3, true)
+    unionDf = df1.unionByName(df2, true).unionByName(df3, true)

-      checkAnswer(unionDf,
-        Row(Row(1, 2, 3, null), 0) ::
-          Row(Row(2, 3, 4, null), 1) ::
-          Row(Row(3, 4, 5, null), 2) :: // df1
-          Row(Row(3, 4, null, null), 0) ::
-          Row(Row(1, 2, null, null), 1) ::
-          Row(Row(2, 3, null, null), 2) :: // df2
-          Row(Row(100, 101, 102, 103), 0) ::
-          Row(Row(110, 111, 112, 113), 1) ::
-          Row(Row(120, 121, 122, 123), 2) :: Nil // df3
-      )
-      assert(unionDf.schema.toDDL ==
-        "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT")
-    }
+    checkAnswer(unionDf,
+      Row(Row(1, 2, 3, null), 0) ::
+        Row(Row(2, 3, 4, null), 1) ::
+        Row(Row(3, 4, 5, null), 2) :: // df1
+        Row(Row(3, 4, null, null), 0) ::
+        Row(Row(1, 2, null, null), 1) ::
+        Row(Row(2, 3, null, null), 2) :: // df2
+        Row(Row(100, 101, 102, 103), 0) ::
+        Row(Row(110, 111, 112, 113), 1) ::
+        Row(Row(120, 121, 122, 123), 2) :: Nil // df3
+    )
+    assert(unionDf.schema.toDDL ==
+      "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT")
   }

   test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - nested") {
-    withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") {
-      val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
-      val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a")
+    val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
+    val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a")

-      val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
-        "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>"
+    val expectedSchema = "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+      "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>"

-      var unionDf = df1.unionByName(df2, true)
-      checkAnswer(unionDf,
-        Row(0, Row(0, 1, Row(1, null, "2"))) ::
-          Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil)
-      assert(unionDf.schema.toDDL == expectedSchema)
+    var unionDf = df1.unionByName(df2, true)
+    checkAnswer(unionDf,
+      Row(0, Row(0, 1, Row(1, null, "2"))) ::
+        Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil)
+    assert(unionDf.schema.toDDL == expectedSchema)

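+    // Reversing the sides of the union must fill the same missing fields and
+    // produce an identical merged schema.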
-      unionDf = df2.unionByName(df1, true)
-      checkAnswer(unionDf,
-        Row(1, Row(1, 2, Row(2, 3L, null))) ::
-          Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil)
-      assert(unionDf.schema.toDDL == expectedSchema)
+    unionDf = df2.unionByName(df1, true)
+    checkAnswer(unionDf,
+      Row(1, Row(1, 2, Row(2, 3L, null))) ::
+        Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil)
+    assert(unionDf.schema.toDDL == expectedSchema)

-      val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a")
-      unionDf = df1.unionByName(df3, true)
-      checkAnswer(unionDf,
-        Row(0, Row(0, 1, Row(1, null, "2"))) ::
-          Row(2, Row(2, 3, null)) :: Nil)
-      assert(unionDf.schema.toDDL == expectedSchema)
-    }
+    val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a")
+    unionDf = df1.unionByName(df3, true)
+    checkAnswer(unionDf,
+      Row(0, Row(0, 1, Row(1, null, "2"))) ::
+        Row(2, Row(2, 3, null)) :: Nil)
+    assert(unionDf.schema.toDDL == expectedSchema)
   }

   test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" +
     " - case-sensitive cases") {
-    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true",
-      SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
       val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
       val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a")

@@ -639,42 +630,26 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
   }

   test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - edge case") {
-    withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") {
-      val nestedStructType1 = StructType(Seq(
-        StructField("b", StringType)))
-      val nestedStructValues1 = Row("b")
-
-      val nestedStructType2 = StructType(Seq(
-        StructField("b", StringType),
-        StructField("a", StringType)))
-      val nestedStructValues2 = Row("b", "a")
-
-      val df1: DataFrame = spark.createDataFrame(
-        sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
-        StructType(Seq(StructField("topLevelCol", nestedStructType1))))
-
-      val df2: DataFrame = spark.createDataFrame(
-        sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
-        StructType(Seq(StructField("topLevelCol", nestedStructType2))))
-
-      val union = df1.unionByName(df2, allowMissingColumns = true)
-      checkAnswer(union, Row(Row(null, "b")) :: Row(Row("a", "b")) :: Nil)
-      assert(union.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRING, `b`: STRING>")
-    }
-  }
-
-  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - disable") {
-    withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "false") {
-      val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx")
-      val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx")
-
-      val err = intercept[AnalysisException] {
-        df1.unionByName(df2, true).collect()
-      }
-      assert(err.getMessage.contains("Union can only be performed on tables with the compatible " +
-        "column types. struct<_1:int,_2:int> <> struct<_1:int,_2:int,_3:int> at the first column " +
-        "of the second table"))
-    }
+    val nestedStructType1 = StructType(Seq(
+      StructField("b", StringType)))
+    val nestedStructValues1 = Row("b")
+
+    val nestedStructType2 = StructType(Seq(
+      StructField("b", StringType),
+      StructField("a", StringType)))
+    val nestedStructValues2 = Row("b", "a")
+
+    val df1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType1))))
+
+    val df2: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
+      StructType(Seq(StructField("topLevelCol", nestedStructType2))))
+
+    val union = df1.unionByName(df2, allowMissingColumns = true)
+    checkAnswer(union, Row(Row(null, "b")) :: Row(Row("a", "b")) :: Nil)
+    assert(union.schema.toDDL == "`topLevelCol` STRUCT<`a`: STRING, `b`: STRING>")
   }

   test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr") {
@@ -704,36 +679,35 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
       df
     }

-    withSQLConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED.key -> "true") {
-      val df1 = nestedDf(depth = 10, numColsAtEachDepth = 1)
-      val df2 = nestedDf(depth = 10, numColsAtEachDepth = 20)
-      val union = df1.unionByName(df2, allowMissingColumns = true)
-      // scalastyle:off
-      val row1 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row(
-        Row(0, 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
-        1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null))
-      val row2 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row(
-        Row(0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
-        1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9))
-      // scalastyle:on
-      checkAnswer(union, row1 :: row2 :: Nil)
-    }
+
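+    // Two 10-level-deep nested structs, with 1 vs. 20 columns at each level;
+    // the union below must null-fill and sort the nested fields at every depth.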
+    val df1 = nestedDf(depth = 10, numColsAtEachDepth = 1)
+    val df2 = nestedDf(depth = 10, numColsAtEachDepth = 20)
+    val union = df1.unionByName(df2, allowMissingColumns = true)
+    // scalastyle:off
+    val row1 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row(
+      Row(0, 1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null),
+      1, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null))
+    val row2 = Row(Row(Row(Row(Row(Row(Row(Row(Row(Row(
+      Row(0, 1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9),
+      1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2, 20, 3, 4, 5, 6, 7, 8, 9))
+    // scalastyle:on
+    checkAnswer(union, row1 :: row2 :: Nil)
   }
 }

From c07e30f18a9474a1a3bcbe913380d963984a50d2 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 16 Oct 2020 09:11:13 -0700
Subject: [PATCH 21/23] For comments.

---
 .../org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 2 +-
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala    | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
index 8dce9895da7d..9a7df14b81d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
@@ -83,7 +83,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
         val newStruct = CreateNamedStruct(sorted)
         newStruct
       case other =>
-        throw new AnalysisException(s"`UpdateFields` has incorrect eval expression: $other. " +
+        throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " +
" + "Please file a bug report with this error message, stack trace, and the query.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2b519b6dbb3d..3d431d6ff13a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2067,8 +2067,7 @@ class Dataset[T] private[sql]( * // +----+----+----+----+ * }}} * - * Note that `allowMissingColumns` supports nested column in struct types, if the config - * `spark.sql.unionByName.structSupport.enabled` is enabled. Missing nested columns + * Note that `allowMissingColumns` supports nested column in struct types. Missing nested columns * of struct columns with same name will also be filled with null values. This currently does not * support nested columns in array and map types. Note that if there is any missing nested columns * to be filled, in order to make consistent schema between two sides of union, the nested fields From 2ca137968859bbd018b3c91372949da0a81e5eaa Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Oct 2020 11:19:46 -0700 Subject: [PATCH 22/23] Make Scala 2.13 build happy. --- .../src/main/scala/org/apache/spark/sql/types/StructType.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 4055e2493a39..c5e76c160ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -673,7 +673,7 @@ object StructType extends AbstractDataType { if (newFields.isEmpty) { None } else { - Some(StructType(newFields)) + Some(StructType(newFields.toSeq)) } } } From 3d907d07233851d6169533d0475c5a53b02cb4a7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 16 Oct 2020 12:47:14 -0700 Subject: [PATCH 23/23] Make Scala 2.13 build happy, again. --- .../org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 9a7df14b81d5..c1a9c9d3d9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -209,7 +209,7 @@ object ResolveUnion extends Rule[LogicalPlan] { } } - (rightProjectList, aliased) + (rightProjectList, aliased.toSeq) } private def unionTwoSides(