Skip to content

Commit e574fcd

Browse files
committed
[SPARK-32376][SQL] Make unionByName null-filling behavior work with struct columns
### What changes were proposed in this pull request? SPARK-29358 added support for `unionByName` to work when the two datasets didn't necessarily have the same schema, but it does not work with nested columns like structs. This patch adds the support to work with struct columns. The behavior before this PR: ```scala scala> val df1 = spark.range(1).selectExpr("id c0", "named_struct('c', id + 1, 'b', id + 2, 'a', id + 3) c1") scala> val df2 = spark.range(1).selectExpr("id c0", "named_struct('c', id + 1, 'b', id + 2) c1") scala> df1.unionByName(df2, true).printSchema org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. struct<c:bigint,b:bigint> <> struct<c:bigint,b:bigint,a:bigint> at the second column of the second table;; 'Union false, false :- Project [id#0L AS c0#2L, named_struct(c, (id#0L + cast(1 as bigint)), b, (id#0L + cast(2 as bigint)), a, (id#0L + cast(3 as bigint))) AS c1#3] : +- Range (0, 1, step=1, splits=Some(12)) +- Project [c0#8L, c1#9] +- Project [id#6L AS c0#8L, named_struct(c, (id#6L + cast(1 as bigint)), b, (id#6L + cast(2 as bigint))) AS c1#9] +- Range (0, 1, step=1, splits=Some(12)) ``` The behavior after this PR: ```scala scala> df1.unionByName(df2, true).printSchema root |-- c0: long (nullable = false) |-- c1: struct (nullable = false) | |-- a: long (nullable = true) | |-- b: long (nullable = false) | |-- c: long (nullable = false) scala> df1.unionByName(df2, true).show() +---+-------------+ | c0| c1| +---+-------------+ | 0| {3, 2, 1}| | 0|{ null, 2, 1}| +---+-------------+ ``` ### Why are the changes needed? The `allowMissingColumns` of `unionByName` is a feature allowing merging different schema from two datasets when unioning them together. Nested column support makes the feature more general and flexible for usage. ### Does this PR introduce _any_ user-facing change? Yes, after this change users can union two datasets with different schema with different structs. ### How was this patch tested? Unit tests. Closes #29587 from viirya/SPARK-32376. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent ce6180c commit e574fcd

File tree

8 files changed

+555
-47
lines changed

8 files changed

+555
-47
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala

Lines changed: 181 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,188 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20+
import scala.collection.mutable
21+
2022
import org.apache.spark.sql.AnalysisException
21-
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
23+
import org.apache.spark.sql.catalyst.expressions._
2224
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
2325
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
2426
import org.apache.spark.sql.catalyst.rules.Rule
2527
import org.apache.spark.sql.internal.SQLConf
28+
import org.apache.spark.sql.types._
2629
import org.apache.spark.sql.util.SchemaUtils
30+
import org.apache.spark.unsafe.types.UTF8String
2731

2832
/**
2933
* Resolves different children of Union to a common set of columns.
3034
*/
3135
object ResolveUnion extends Rule[LogicalPlan] {
32-
private def unionTwoSides(
36+
/**
37+
* This method sorts columns recursively in a struct expression based on column names.
38+
*/
39+
private def sortStructFields(expr: Expression): Expression = {
40+
val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
41+
case (name, i) =>
42+
val fieldExpr = GetStructField(KnownNotNull(expr), i)
43+
if (fieldExpr.dataType.isInstanceOf[StructType]) {
44+
(name, sortStructFields(fieldExpr))
45+
} else {
46+
(name, fieldExpr)
47+
}
48+
}.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
49+
50+
val newExpr = CreateNamedStruct(existingExprs)
51+
if (expr.nullable) {
52+
If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
53+
} else {
54+
newExpr
55+
}
56+
}
57+
58+
/**
59+
* Assumes input expressions are field expression of `CreateNamedStruct`. This method
60+
* sorts the expressions based on field names.
61+
*/
62+
private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
63+
fieldExprs.grouped(2).map { e =>
64+
Seq(e.head, e.last)
65+
}.toSeq.sortBy { pair =>
66+
assert(pair.head.isInstanceOf[Literal])
67+
pair.head.eval().asInstanceOf[UTF8String].toString
68+
}.flatten
69+
}
70+
71+
/**
72+
* This helper method sorts fields in a `UpdateFields` expression by field name.
73+
*/
74+
private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
75+
case u: UpdateFields if u.resolved =>
76+
u.evalExpr match {
77+
case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
78+
val sorted = sortFieldExprs(fieldExprs)
79+
val newStruct = CreateNamedStruct(sorted)
80+
i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
81+
case CreateNamedStruct(fieldExprs) =>
82+
val sorted = sortFieldExprs(fieldExprs)
83+
val newStruct = CreateNamedStruct(sorted)
84+
newStruct
85+
case other =>
86+
throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " +
87+
"Please file a bug report with this error message, stack trace, and the query.")
88+
}
89+
}
90+
91+
def simplifyWithFields(expr: Expression): Expression = {
92+
expr.transformUp {
93+
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
94+
UpdateFields(struct, fieldOps1 ++ fieldOps2)
95+
}
96+
}
97+
98+
/**
99+
* Adds missing fields recursively into given `col` expression, based on the target `StructType`.
100+
* This is called by `compareAndAddFields` when we find two struct columns with same name but
101+
* different nested fields. This method will find out the missing nested fields from `col` to
102+
* `target` struct and add these missing nested fields. Currently we don't support finding out
103+
* missing nested fields of struct nested in array or struct nested in map.
104+
*/
105+
private def addFields(col: NamedExpression, target: StructType): Expression = {
106+
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")
107+
108+
val resolver = SQLConf.get.resolver
109+
val missingFieldsOpt =
110+
StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)
111+
112+
// We need to sort columns in result, because we might add another column in other side.
113+
// E.g., we want to union two structs "a int, b long" and "a int, c string".
114+
// If we don't sort, we will have "a int, b long, c string" and
115+
// "a int, c string, b long", which are not compatible.
116+
if (missingFieldsOpt.isEmpty) {
117+
sortStructFields(col)
118+
} else {
119+
missingFieldsOpt.map { s =>
120+
val struct = addFieldsInto(col, s.fields)
121+
// Combines `WithFields`s to reduce expression tree.
122+
val reducedStruct = simplifyWithFields(struct)
123+
val sorted = sortStructFieldsInWithFields(reducedStruct)
124+
sorted
125+
}.get
126+
}
127+
}
128+
129+
/**
130+
* Adds missing fields recursively into given `col` expression. The missing fields are given
131+
* in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and `fields` is
132+
* "z struct<w:long>, w string". This method will add a nested `z.w` field and a top-level
133+
* `w` field to `col` and fill null values for them. Note that because we might also add missing
134+
* fields at other side of Union, we must make sure corresponding attributes at two sides have
135+
* same field order in structs, so when we adding missing fields, we will sort the fields based on
136+
* field names. So the data type of returned expression will be
137+
* "w string, x int, z struct<w:long, y:int, z:int>".
138+
*/
139+
private def addFieldsInto(
140+
col: Expression,
141+
fields: Seq[StructField]): Expression = {
142+
fields.foldLeft(col) { case (currCol, field) =>
143+
field.dataType match {
144+
case st: StructType =>
145+
val resolver = SQLConf.get.resolver
146+
val colField = currCol.dataType.asInstanceOf[StructType]
147+
.find(f => resolver(f.name, field.name))
148+
if (colField.isEmpty) {
149+
// The whole struct is missing. Add a null.
150+
UpdateFields(currCol, field.name, Literal(null, st))
151+
} else {
152+
UpdateFields(currCol, field.name,
153+
addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields))
154+
}
155+
case dt =>
156+
UpdateFields(currCol, field.name, Literal(null, dt))
157+
}
158+
}
159+
}
160+
161+
/**
162+
* This method will compare right to left plan's outputs. If there is one struct attribute
163+
* at right side has same name with left side struct attribute, but two structs are not the
164+
* same data type, i.e., some missing (nested) fields at right struct attribute, then this
165+
* method will try to add missing (nested) fields into the right attribute with null values.
166+
*/
167+
private def compareAndAddFields(
33168
left: LogicalPlan,
34169
right: LogicalPlan,
35-
allowMissingCol: Boolean): LogicalPlan = {
170+
allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
36171
val resolver = SQLConf.get.resolver
37172
val leftOutputAttrs = left.output
38173
val rightOutputAttrs = right.output
39174

40-
// Builds a project list for `right` based on `left` output names
175+
val aliased = mutable.ArrayBuffer.empty[Attribute]
176+
41177
val rightProjectList = leftOutputAttrs.map { lattr =>
42-
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
178+
val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
179+
if (found.isDefined) {
180+
val foundAttr = found.get
181+
val foundDt = foundAttr.dataType
182+
(foundDt, lattr.dataType) match {
183+
case (source: StructType, target: StructType)
184+
if allowMissingCol && !source.sameType(target) =>
185+
// Having an output with same name, but different struct type.
186+
// We need to add missing fields. Note that if there are deeply nested structs such as
187+
// nested struct of array in struct, we don't support to add missing deeply nested field
188+
// like that. We will sort columns in the struct expression to make sure two sides of
189+
// union have consistent schema.
190+
aliased += foundAttr
191+
Alias(addFields(foundAttr, target), foundAttr.name)()
192+
case _ =>
193+
// We don't need/try to add missing fields if:
194+
// 1. The attributes of left and right side are the same struct type
195+
// 2. The attributes are not struct types. They might be primitive types, or array, map
196+
// types. We don't support adding missing fields of nested structs in array or map
197+
// types now.
198+
// 3. `allowMissingCol` is disabled.
199+
foundAttr
200+
}
201+
} else {
43202
if (allowMissingCol) {
44203
Alias(Literal(null, lattr.dataType), lattr.name)()
45204
} else {
@@ -50,18 +209,29 @@ object ResolveUnion extends Rule[LogicalPlan] {
50209
}
51210
}
52211

212+
(rightProjectList, aliased.toSeq)
213+
}
214+
215+
private def unionTwoSides(
216+
left: LogicalPlan,
217+
right: LogicalPlan,
218+
allowMissingCol: Boolean): LogicalPlan = {
219+
val rightOutputAttrs = right.output
220+
221+
// Builds a project list for `right` based on `left` output names
222+
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)
223+
53224
// Delegates failure checks to `CheckAnalysis`
54-
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
225+
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
55226
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)
56227

57228
// Builds a project for `logicalPlan` based on `right` output names, if allowing
58229
// missing columns.
59230
val leftChild = if (allowMissingCol) {
60-
val missingAttrs = notFoundAttrs.map { attr =>
61-
Alias(Literal(null, attr.dataType), attr.name)()
62-
}
63-
if (missingAttrs.nonEmpty) {
64-
Project(leftOutputAttrs ++ missingAttrs, left)
231+
// Add missing (nested) fields to left plan.
232+
val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
233+
if (leftProjectList.map(_.toAttribute) != left.output) {
234+
Project(leftProjectList, left)
65235
} else {
66236
left
67237
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
2020
import scala.collection.mutable.ArrayBuffer
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion}
23+
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
2424
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
2525
import org.apache.spark.sql.catalyst.expressions.codegen._
2626
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
27+
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
2728
import org.apache.spark.sql.catalyst.util._
2829
import org.apache.spark.sql.internal.SQLConf
2930
import org.apache.spark.sql.types._
@@ -661,3 +662,52 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat
661662
}
662663
}
663664
}
665+
666+
object UpdateFields {
667+
private def nameParts(fieldName: String): Seq[String] = {
668+
require(fieldName != null, "fieldName cannot be null")
669+
670+
if (fieldName.isEmpty) {
671+
fieldName :: Nil
672+
} else {
673+
CatalystSqlParser.parseMultipartIdentifier(fieldName)
674+
}
675+
}
676+
677+
/**
678+
* Adds/replaces field of `StructType` into `col` expression by name.
679+
*/
680+
def apply(col: Expression, fieldName: String, expr: Expression): UpdateFields = {
681+
updateFieldsHelper(col, nameParts(fieldName), name => WithField(name, expr))
682+
}
683+
684+
/**
685+
* Drops fields of `StructType` in `col` expression by name.
686+
*/
687+
def apply(col: Expression, fieldName: String): UpdateFields = {
688+
updateFieldsHelper(col, nameParts(fieldName), name => DropField(name))
689+
}
690+
691+
private def updateFieldsHelper(
692+
structExpr: Expression,
693+
namePartsRemaining: Seq[String],
694+
valueFunc: String => StructFieldsOperation) : UpdateFields = {
695+
val fieldName = namePartsRemaining.head
696+
if (namePartsRemaining.length == 1) {
697+
UpdateFields(structExpr, valueFunc(fieldName) :: Nil)
698+
} else {
699+
val newStruct = if (structExpr.resolved) {
700+
val resolver = SQLConf.get.resolver
701+
ExtractValue(structExpr, Literal(fieldName), resolver)
702+
} else {
703+
UnresolvedExtractValue(structExpr, Literal(fieldName))
704+
}
705+
706+
val newValue = updateFieldsHelper(
707+
structExpr = newStruct,
708+
namePartsRemaining = namePartsRemaining.tail,
709+
valueFunc = valueFunc)
710+
UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil)
711+
}
712+
}
713+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,10 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
116116
s"$child.${name.getOrElse(fieldName)}"
117117
}
118118

119+
def extractFieldName: String = name.getOrElse(childSchema(ordinal).name)
120+
119121
override def sql: String =
120-
child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
122+
child.sql + s".${quoteIdentifier(extractFieldName)}"
121123

122124
protected override def nullSafeEval(input: Any): Any =
123125
input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,4 +641,39 @@ object StructType extends AbstractDataType {
641641
fields.foreach(s => map.put(s.name, s))
642642
map
643643
}
644+
645+
/**
646+
* Returns a `StructType` that contains missing fields recursively from `source` to `target`.
647+
* Note that this doesn't support looking into array type and map type recursively.
648+
*/
649+
def findMissingFields(
650+
source: StructType,
651+
target: StructType,
652+
resolver: Resolver): Option[StructType] = {
653+
def bothStructType(dt1: DataType, dt2: DataType): Boolean =
654+
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]
655+
656+
val newFields = mutable.ArrayBuffer.empty[StructField]
657+
658+
target.fields.foreach { field =>
659+
val found = source.fields.find(f => resolver(field.name, f.name))
660+
if (found.isEmpty) {
661+
// Found a missing field in `source`.
662+
newFields += field
663+
} else if (bothStructType(found.get.dataType, field.dataType) &&
664+
!found.get.dataType.sameType(field.dataType)) {
665+
// Found a field with same name, but different data type.
666+
findMissingFields(found.get.dataType.asInstanceOf[StructType],
667+
field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
668+
newFields += found.get.copy(dataType = missingType)
669+
}
670+
}
671+
}
672+
673+
if (newFields.isEmpty) {
674+
None
675+
} else {
676+
Some(StructType(newFields.toSeq))
677+
}
678+
}
644679
}

0 commit comments

Comments
 (0)