Skip to content
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,190 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {
private def unionTwoSides(
/**
* This method sorts recursively columns in a struct expression based on column names.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorts recursively columns -> sorts columns recursively

*/
private def sortStructFields(expr: Expression): Expression = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are functions having the same names, so could we assign different names? I think its a bit confusing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, let me think about better method names.

val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
case (name, i) =>
val fieldExpr = GetStructField(KnownNotNull(expr), i)
if (fieldExpr.dataType.isInstanceOf[StructType]) {
(name, sortStructFields(fieldExpr))
} else {
(name, fieldExpr)
}
}.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))

val newExpr = CreateNamedStruct(existingExprs)
if (expr.nullable) {
If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
} else {
newExpr
}
}

/**
* Assumes input expressions are field expression of `CreateNamedStruct`. This method
* sorts the expressions based on field names.
*/
private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
fieldExprs.grouped(2).map { e =>
Seq(e.head, e.last)
}.toSeq.sortBy { pair =>
assert(pair.head.isInstanceOf[Literal])
pair.head.eval().asInstanceOf[UTF8String].toString
}.flatten
}

/**
* This helper method sorts fields in a `UpdateFields` expression by field name.
*/
private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
case u: UpdateFields if u.resolved =>
u.evalExpr match {
case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
val sorted = sortFieldExprs(fieldExprs)
val newStruct = CreateNamedStruct(sorted)
i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
case CreateNamedStruct(fieldExprs) =>
val sorted = sortFieldExprs(fieldExprs)
val newStruct = CreateNamedStruct(sorted)
newStruct
case other =>
throw new AnalysisException(s"`UpdateFields` has incorrect eval expression: $other. " +
"Please file a bug report with this error message, stack trace, and the query.")
}
}

def simplifyWithFields(expr: Expression): Expression = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: private. Btw, all the transformations in this method will be moved into an optimizer rule in followup? We normally add tests when adding a new rule, but this PR does not have any test for them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, actually there is #29812 for that, but is stuck by other PR that is refactoring WithFields.

expr.transformUp {
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it important to have this optimization inside this analyzer rule?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea. Without optimizing the expressions, we cannot scale up well for deeply nested schema, e.g. the added test SPARK-32376: Make unionByName null-filling behavior work with struct columns - deep expr. in DataFrameSetOperationsSuite.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I plan to move this optimization out of ResolveUnion into a separate rule in analyzer in #29812. For complex deeply nested schema, it is easier to write inefficient expression tree that is very slow in analysis phase. For the test case in this PR, it is unable to evaluate the query at all, but after adding this optimization, it can normally evaluate.

UpdateFields(struct, fieldOps1 ++ fieldOps2)
}
}

/**
* Adds missing fields recursively into given `col` expression, based on the target `StructType`.
* This is called by `compareAndAddFields` when we find two struct columns with same name but
* different nested fields. This method will find out the missing nested fields from `col` to
* `target` struct and add these missing nested fields. Currently we don't support finding out
* missing nested fields of struct nested in array or struct nested in map.
*/
private def addFields(col: NamedExpression, target: StructType): Expression = {
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")

val resolver = SQLConf.get.resolver
val missingFields =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name is a bit misleading and I though it's a Seq. How about missingFieldsOpt?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good. Fixed.

StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)

// We need to sort columns in result, because we might add another column in other side.
// E.g., we want to union two structs "a int, b long" and "a int, c string".
// If we don't sort, we will have "a int, b long, c string" and
// "a int, c string, b long", which are not compatible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this behavior consistent with top-level columns?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is related to the comment: #29587 (comment)

if (missingFields.isEmpty) {
sortStructFields(col)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to sort names recursively for nested struct cases?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, we need this to make sure two sides have consistent schema. For example the test case from @fqaiser94 in #29587 (comment), when we add field to one side, another side still needs to sort its column, otherwise there is inconsistency.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu I got you point when I was fixing the performance issue. Yeah, we should. I fixed it in latest commit. Thanks.

} else {
missingFields.map { s =>
val struct = addFieldsInto(col, s.fields)
// Combines `WithFields`s to reduce expression tree.
val reducedStruct = simplifyWithFields(struct)
val sorted = sortStructFieldsInWithFields(reducedStruct)
sorted
}.get
}
}

/**
* Adds missing fields recursively into given `col` expression. The missing fields are given
* in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and `fields` is
* "z struct<w:long>, w string". This method will add a nested `z.w` field and a top-level
* `w` field to `col` and fill null values for them. Note that because we might also add missing
* fields at other side of Union, we must make sure corresponding attributes at two sides have
* same field order in structs, so when we adding missing fields, we will sort the fields based on
* field names. So the data type of returned expression will be
* "w string, x int, z struct<w:long, y:int, z:int>".
*/
private def addFieldsInto(
col: Expression,
fields: Seq[StructField]): Expression = {
fields.foldLeft(col) { case (currCol, field) =>
field.dataType match {
case st: StructType =>
val resolver = SQLConf.get.resolver
val colField = currCol.dataType.asInstanceOf[StructType]
.find(f => resolver(f.name, field.name))
if (colField.isEmpty) {
// The whole struct is missing. Add a null.
UpdateFields(currCol, field.name, Literal(null, st))
} else {
UpdateFields(currCol, field.name,
addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields))
}
case dt =>
UpdateFields(currCol, field.name, Literal(null, dt))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if byName is true but allowMissingCol is false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If allowMissingCol is false, we don't compare and add top-level/nested columns. If two sides have inconsistent schema, the union doesn't pass analysis.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The top-level columns support byName and allowMissingCol individually, shall we do it for nested columns as well? Or we plan to do it in followup?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. byName support actually means we need to adjust columns between two sides to have a consistent schema. It could be top-level or nested column cases.

So it is actually the same issue as #29587 (comment), a.k.a adjusting the nested columns to have a more natural schema. As replied in the discussion, I plan to do it in followup.

}
}
}

/**
* This method will compare right to left plan's outputs. If there is one struct attribute
* at right side has same name with left side struct attribute, but two structs are not the
* same data type, i.e., some missing (nested) fields at right struct attribute, then this
* method will try to add missing (nested) fields into the right attribute with null values.
*/
private def compareAndAddFields(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although we have a rich comment in the function body, could you add a function description to give a general idea?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
val resolver = SQLConf.get.resolver
val leftOutputAttrs = left.output
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val aliased = mutable.ArrayBuffer.empty[Attribute]

val supportStruct = SQLConf.get.unionByNameStructSupportEnabled

val rightProjectList = leftOutputAttrs.map { lattr =>
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
if (found.isDefined) {
val foundAttr = found.get
val foundDt = foundAttr.dataType
(foundDt, lattr.dataType) match {
case (source: StructType, target: StructType)
if supportStruct && allowMissingCol && !source.sameType(target) =>
// Having an output with same name, but different struct type.
// We need to add missing fields. Note that if there are deeply nested structs such as
// nested struct of array in struct, we don't support to add missing deeply nested field
// like that. We will sort columns in the struct expression to make sure two sides of
// union have consistent schema.
aliased += foundAttr
Alias(addFields(foundAttr, target), foundAttr.name)()
case _ =>
// We don't need/try to add missing fields if:
// 1. The attributes of left and right side are the same struct type
// 2. The attributes are not struct types. They might be primitive types, or array, map
// types. We don't support adding missing fields of nested structs in array or map
// types now.
// 3. `allowMissingCol` is disabled.
foundAttr
}
} else {
if (allowMissingCol) {
Alias(Literal(null, lattr.dataType), lattr.name)()
} else {
Expand All @@ -50,18 +211,29 @@ object ResolveUnion extends Rule[LogicalPlan] {
}
}

(rightProjectList, aliased)
}

private def unionTwoSides(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)

// Delegates failure checks to `CheckAnalysis`
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)

// Builds a project for `logicalPlan` based on `right` output names, if allowing
// missing columns.
val leftChild = if (allowMissingCol) {
val missingAttrs = notFoundAttrs.map { attr =>
Alias(Literal(null, attr.dataType), attr.name)()
}
if (missingAttrs.nonEmpty) {
Project(leftOutputAttrs ++ missingAttrs, left)
// Add missing (nested) fields to left plan.
val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
if (leftProjectList.map(_.toAttribute) != left.output) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

      if (leftProjectList.length != left.output.length ||
          leftProjectList.map(_.toAttribute) != left.output) {

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't leftProjectList.map(_.toAttribute) != left.output already cover leftProjectList.length != left.output.length?

Project(leftProjectList, left)
} else {
left
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -661,3 +662,52 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat
}
}
}

object UpdateFields {
private def nameParts(fieldName: String): Seq[String] = {
require(fieldName != null, "fieldName cannot be null")

if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
}

/**
* Adds/replaces field of `StructType` into `col` expression by name.
*/
def apply(col: Expression, fieldName: String, expr: Expression): UpdateFields = {
updateFieldsHelper(col, nameParts(fieldName), name => WithField(name, expr))
}

/**
* Drops fields of `StructType` in `col` expression by name.
*/
def apply(col: Expression, fieldName: String): UpdateFields = {
updateFieldsHelper(col, nameParts(fieldName), name => DropField(name))
}

private def updateFieldsHelper(
structExpr: Expression,
namePartsRemaining: Seq[String],
valueFunc: String => StructFieldsOperation) : UpdateFields = {
val fieldName = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
UpdateFields(structExpr, valueFunc(fieldName) :: Nil)
} else {
val newStruct = if (structExpr.resolved) {
val resolver = SQLConf.get.resolver
ExtractValue(structExpr, Literal(fieldName), resolver)
} else {
UnresolvedExtractValue(structExpr, Literal(fieldName))
}

val newValue = updateFieldsHelper(
structExpr = newStruct,
namePartsRemaining = namePartsRemaining.tail,
valueFunc = valueFunc)
UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
s"$child.${name.getOrElse(fieldName)}"
}

def extractFieldName: String = name.getOrElse(childSchema(ordinal).name)

override def sql: String =
child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
child.sql + s".${quoteIdentifier(extractFieldName)}"

protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2740,6 +2740,19 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val UNION_BYNAME_STRUCT_SUPPORT_ENABLED =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this worths a config. It's natural that the byName and allowMissingCol flag should apply to nested column, and these 2 flags are newly added in the master branch so there is no backward compatibility issues.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sounds more correct. I will remove this config.

buildConf("spark.sql.unionByName.structSupport.enabled")
.doc("When true, the `allowMissingColumns` feature of `Dataset.unionByName` supports " +
"nested column in struct types. Missing nested columns of struct columns with same " +
"name will also be filled with null values. This currently does not support nested " +
"columns in array and map types. Note that if there is any missing nested columns " +
"to be filled, in order to make consistent schema between two sides of union, the " +
"nested fields of structs will be sorted after merging schema."
)
.version("3.1.0")
.booleanConf
.createWithDefault(true)

val LEGACY_PATH_OPTION_BEHAVIOR =
buildConf("spark.sql.legacy.pathOptionBehavior.enabled")
.internal()
Expand Down Expand Up @@ -3089,6 +3102,9 @@ class SQLConf extends Serializable with Logging {
LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY))
}

def unionByNameStructSupportEnabled: Boolean =
getConf(SQLConf.UNION_BYNAME_STRUCT_SUPPORT_ENABLED)

def broadcastHashJoinOutputPartitioningExpandLimit: Int =
getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,4 +641,39 @@ object StructType extends AbstractDataType {
fields.foreach(s => map.put(s.name, s))
map
}

/**
* Returns a `StructType` that contains missing fields recursively from `source` to `target`.
* Note that this doesn't support looking into array type and map type recursively.
Copy link
Member

@maropu maropu Aug 31, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this limitation come?; we don't need to support this case, or supporting it is technically difficult? Ah, I see. Is this an unsupported case, right?
https://github.com/apache/spark/pull/29587/files#diff-4d656d696512d6bcb03a48f7e0af6251R106-R107

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I leverage WithFields to add missing nested fields into structs. WithFields doesn't support array or map types currently.

*/
def findMissingFields(
source: StructType,
target: StructType,
resolver: Resolver): Option[StructType] = {
def bothStructType(dt1: DataType, dt2: DataType): Boolean =
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]

val newFields = mutable.ArrayBuffer.empty[StructField]

target.fields.foreach { field =>
val found = source.fields.find(f => resolver(field.name, f.name))
if (found.isEmpty) {
// Found a missing field in `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
!found.get.dataType.sameType(field.dataType)) {
// Found a field with same name, but different data type.
findMissingFields(found.get.dataType.asInstanceOf[StructType],
field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
newFields += found.get.copy(dataType = missingType)
}
}
}

if (newFields.isEmpty) {
None
} else {
Some(StructType(newFields))
}
}
}
Loading