-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35290][SQL] Append new nested struct fields rather than sort for unionByName with null filling #32448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,137 +20,56 @@ package org.apache.spark.sql.catalyst.analysis | |
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields} | ||
| import org.apache.spark.sql.catalyst.optimizer.{CombineUnions} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.catalyst.trees.TreePattern.UNION | ||
| import org.apache.spark.sql.errors.QueryCompilationErrors | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.sql.util.SchemaUtils | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
|
|
||
| /** | ||
| * Resolves different children of Union to a common set of columns. | ||
| */ | ||
| object ResolveUnion extends Rule[LogicalPlan] { | ||
| /** | ||
| * This method sorts columns recursively in a struct expression based on column names. | ||
| * Adds missing fields recursively into given `col` expression, based on the expected struct | ||
| * fields from merging the two schemas. This is called by `compareAndAddFields` when we find two | ||
| * struct columns with same name but different nested fields. This method will recursively | ||
| * return a new struct with all of the expected fields, adding null values when `col` doesn't | ||
| * already contain them. Currently we don't support merging structs nested inside of arrays | ||
| * or maps. | ||
| */ | ||
| private def sortStructFields(expr: Expression): Expression = { | ||
| val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { | ||
| case (name, i) => | ||
| val fieldExpr = GetStructField(KnownNotNull(expr), i) | ||
| if (fieldExpr.dataType.isInstanceOf[StructType]) { | ||
| (name, sortStructFields(fieldExpr)) | ||
| } else { | ||
| (name, fieldExpr) | ||
| } | ||
| }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2)) | ||
|
|
||
| val newExpr = CreateNamedStruct(existingExprs) | ||
| if (expr.nullable) { | ||
| If(IsNull(expr), Literal(null, newExpr.dataType), newExpr) | ||
| } else { | ||
| newExpr | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Assumes input expressions are field expression of `CreateNamedStruct`. This method | ||
| * sorts the expressions based on field names. | ||
| */ | ||
| private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = { | ||
| fieldExprs.grouped(2).map { e => | ||
| Seq(e.head, e.last) | ||
| }.toSeq.sortBy { pair => | ||
| assert(pair.head.isInstanceOf[Literal]) | ||
| pair.head.eval().asInstanceOf[UTF8String].toString | ||
| }.flatten | ||
| } | ||
|
|
||
| /** | ||
| * This helper method sorts fields in a `UpdateFields` expression by field name. | ||
| */ | ||
| private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp { | ||
| case u: UpdateFields if u.resolved => | ||
| u.evalExpr match { | ||
| case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) => | ||
| val sorted = sortFieldExprs(fieldExprs) | ||
| val newStruct = CreateNamedStruct(sorted) | ||
| i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct) | ||
| case CreateNamedStruct(fieldExprs) => | ||
| val sorted = sortFieldExprs(fieldExprs) | ||
| val newStruct = CreateNamedStruct(sorted) | ||
| newStruct | ||
| case other => | ||
| throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " + | ||
| "Please file a bug report with this error message, stack trace, and the query.") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Adds missing fields recursively into given `col` expression, based on the target `StructType`. | ||
| * This is called by `compareAndAddFields` when we find two struct columns with same name but | ||
| * different nested fields. This method will find out the missing nested fields from `col` to | ||
| * `target` struct and add these missing nested fields. Currently we don't support finding out | ||
| * missing nested fields of struct nested in array or struct nested in map. | ||
| */ | ||
| private def addFields(col: NamedExpression, target: StructType): Expression = { | ||
| private def addFields(col: Expression, expectedFields: Seq[StructField]): Expression = { | ||
| assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") | ||
|
|
||
| val resolver = conf.resolver | ||
| val missingFieldsOpt = | ||
| StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) | ||
|
|
||
| // We need to sort columns in result, because we might add another column in other side. | ||
| // E.g., we want to union two structs "a int, b long" and "a int, c string". | ||
| // If we don't sort, we will have "a int, b long, c string" and | ||
| // "a int, c string, b long", which are not compatible. | ||
| if (missingFieldsOpt.isEmpty) { | ||
| sortStructFields(col) | ||
| } else { | ||
| missingFieldsOpt.map { s => | ||
| val struct = addFieldsInto(col, s.fields) | ||
| // Combines `WithFields`s to reduce expression tree. | ||
| val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields) | ||
| val sorted = sortStructFieldsInWithFields(reducedStruct) | ||
| sorted | ||
| }.get | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Adds missing fields recursively into given `col` expression. The missing fields are given | ||
| * in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and `fields` is | ||
| * "z struct<w:long>, w string". This method will add a nested `z.w` field and a top-level | ||
| * `w` field to `col` and fill null values for them. Note that because we might also add missing | ||
| * fields at other side of Union, we must make sure corresponding attributes at two sides have | ||
| * same field order in structs, so when we adding missing fields, we will sort the fields based on | ||
| * field names. So the data type of returned expression will be | ||
| * "w string, x int, z struct<w:long, y:int, z:int>". | ||
| */ | ||
| private def addFieldsInto( | ||
| col: Expression, | ||
| fields: Seq[StructField]): Expression = { | ||
| fields.foldLeft(col) { case (currCol, field) => | ||
| field.dataType match { | ||
| case st: StructType => | ||
| val resolver = conf.resolver | ||
| val colField = currCol.dataType.asInstanceOf[StructType] | ||
| .find(f => resolver(f.name, field.name)) | ||
| if (colField.isEmpty) { | ||
| // The whole struct is missing. Add a null. | ||
| UpdateFields(currCol, field.name, Literal(null, st)) | ||
| } else { | ||
| UpdateFields(currCol, field.name, | ||
| addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields)) | ||
| } | ||
| case dt => | ||
| UpdateFields(currCol, field.name, Literal(null, dt)) | ||
| val colType = col.dataType.asInstanceOf[StructType] | ||
| val newStructFields = expectedFields.flatMap { expectedField => | ||
| val currentField = colType.fields.find(f => resolver(f.name, expectedField.name)) | ||
|
|
||
| val newExpression = (currentField, expectedField.dataType) match { | ||
| case (Some(cf), expectedType: StructType) if cf.dataType.isInstanceOf[StructType] => | ||
| val extractedValue = ExtractValue(col, Literal(cf.name), resolver) | ||
| val combinedStruct = addFields(extractedValue, expectedType.fields) | ||
| if (extractedValue.nullable) { | ||
| If(IsNull(extractedValue), | ||
| Literal(null, combinedStruct.dataType), | ||
| combinedStruct) | ||
| } else { | ||
| combinedStruct | ||
| } | ||
| case (Some(cf), _) => | ||
| ExtractValue(col, Literal(cf.name), resolver) | ||
| case (None, expectedType) => | ||
| Literal(null, expectedType) | ||
| } | ||
| Literal(expectedField.name) :: newExpression :: Nil | ||
| } | ||
| CreateNamedStruct(newStructFields) | ||
| } | ||
|
|
||
|
|
||
| /** | ||
| * This method will compare right to left plan's outputs. If there is one struct attribute | ||
| * at right side has same name with left side struct attribute, but two structs are not the | ||
|
|
@@ -181,7 +100,8 @@ object ResolveUnion extends Rule[LogicalPlan] { | |
| // like that. We will sort columns in the struct expression to make sure two sides of | ||
| // union have consistent schema. | ||
| aliased += foundAttr | ||
| Alias(addFields(foundAttr, target), foundAttr.name)() | ||
| val targetType = target.merge(source, conf.resolver) | ||
|
||
| Alias(addFields(foundAttr, targetType.fields.toSeq), foundAttr.name)() | ||
| case _ => | ||
| // We don't need/try to add missing fields if: | ||
| // 1. The attributes of left and right side are the same struct type | ||
|
|
@@ -208,13 +128,11 @@ object ResolveUnion extends Rule[LogicalPlan] { | |
| left: LogicalPlan, | ||
| right: LogicalPlan, | ||
| allowMissingCol: Boolean): LogicalPlan = { | ||
| val rightOutputAttrs = right.output | ||
|
|
||
| // Builds a project list for `right` based on `left` output names | ||
| val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) | ||
|
|
||
| // Delegates failure checks to `CheckAnalysis` | ||
| val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) | ||
| val notFoundAttrs = right.output.diff(rightProjectList ++ aliased) | ||
| val rightChild = Project(rightProjectList ++ notFoundAttrs, right) | ||
|
|
||
| // Builds a project for `logicalPlan` based on `right` output names, if allowing | ||
|
|
@@ -230,6 +148,7 @@ object ResolveUnion extends Rule[LogicalPlan] { | |
| } else { | ||
| left | ||
| } | ||
|
|
||
| Union(leftChild, rightChild) | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.