Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/sql-migration-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ license: |
- In Spark 3.2, `FloatType` is mapped to `FLOAT` in MySQL. Prior to this, it used to be mapped to `REAL`, which is by default a synonym for `DOUBLE PRECISION` in MySQL.

- In Spark 3.2, the query executions triggered by `DataFrameWriter` are always named `command` when being sent to `QueryExecutionListener`. In Spark 3.1 and earlier, the name is one of `save`, `insertInto`, `saveAsTable`, `create`, `append`, `overwrite`, `overwritePartitions`, `replace`.

- In Spark 3.2, `Dataset.unionByName` with `allowMissingColumns` set to true will add missing nested fields to the end of structs. In Spark 3.1, nested struct fields are sorted alphabetically.

## Upgrading from Spark SQL 3.0 to 3.1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,137 +20,56 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields}
import org.apache.spark.sql.catalyst.optimizer.{CombineUnions}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.UNION
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {
/**
* This method sorts columns recursively in a struct expression based on column names.
* Adds missing fields recursively into given `col` expression, based on the expected struct
* fields from merging the two schemas. This is called by `compareAndAddFields` when we find two
* struct columns with same name but different nested fields. This method will recursively
* return a new struct with all of the expected fields, adding null values when `col` doesn't
* already contain them. Currently we don't support merging structs nested inside of arrays
* or maps.
*/
// Rebuilds the struct produced by `expr` so that its fields appear in ascending name order,
// applying the same reordering to every field that is itself a struct. When `expr` is nullable
// the rebuilt struct is guarded so that a null input still yields a null of the new type.
private def sortStructFields(expr: Expression): Expression = {
  val structType = expr.dataType.asInstanceOf[StructType]

  // Pair every field name with the expression that extracts (and, if needed, re-sorts) it.
  val namedChildren: Seq[(String, Expression)] =
    structType.fieldNames.toSeq.zipWithIndex.map { case (fieldName, ordinal) =>
      val extracted = GetStructField(KnownNotNull(expr), ordinal)
      val child: Expression = extracted.dataType match {
        case _: StructType => sortStructFields(extracted)
        case _ => extracted
      }
      fieldName -> child
    }

  // `CreateNamedStruct` takes a flat, alternating (name literal, value) argument list.
  val structArgs: Seq[Expression] = namedChildren
    .sortBy { case (fieldName, _) => fieldName }
    .flatMap { case (fieldName, child) => Seq(Literal(fieldName), child) }

  val sortedStruct = CreateNamedStruct(structArgs)
  if (!expr.nullable) {
    sortedStruct
  } else {
    If(IsNull(expr), Literal(null, sortedStruct.dataType), sortedStruct)
  }
}

/**
 * Reorders the flat child list of a `CreateNamedStruct` by field name. The input is read as
 * consecutive (name, value) pairs; each name is asserted to be a `Literal` and is evaluated
 * to obtain the string used as the sort key. The result is again a flat, alternating list.
 */
private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
  val nameValuePairs = fieldExprs.grouped(2).map(group => (group.head, group.last)).toSeq

  val orderedPairs = nameValuePairs.sortBy { case (nameExpr, _) =>
    assert(nameExpr.isInstanceOf[Literal])
    nameExpr.eval().asInstanceOf[UTF8String].toString
  }

  orderedPairs.flatMap { case (nameExpr, valueExpr) => Seq(nameExpr, valueExpr) }
}

/**
 * Walks `expr` bottom-up and replaces every resolved `UpdateFields` with the struct-building
 * expression it evaluates to, with that struct's fields sorted by name. The evaluated form is
 * expected to be either a bare `CreateNamedStruct` or an `If(IsNull(..), .., CreateNamedStruct)`
 * null guard; anything else is reported as an internal error.
 */
private def sortStructFieldsInWithFields(expr: Expression): Expression = {
  // Builds a new struct from the given (name, value) children, ordered by field name.
  def rebuildSorted(fieldExprs: Seq[Expression]): CreateNamedStruct =
    CreateNamedStruct(sortFieldExprs(fieldExprs))

  expr.transformUp {
    case update: UpdateFields if update.resolved =>
      update.evalExpr match {
        case CreateNamedStruct(fieldExprs) =>
          rebuildSorted(fieldExprs)
        case nullGuard @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
          val sortedStruct = rebuildSorted(fieldExprs)
          // The null branch must carry the data type of the reordered struct.
          nullGuard.copy(
            trueValue = Literal(null, sortedStruct.dataType),
            falseValue = sortedStruct)
        case other =>
          throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " +
            "Please file a bug report with this error message, stack trace, and the query.")
      }
  }
}

/**
* Adds missing fields recursively into given `col` expression, based on the target `StructType`.
* This is called by `compareAndAddFields` when we find two struct columns with same name but
* different nested fields. This method will find out the missing nested fields from `col` to
* `target` struct and add these missing nested fields. Currently we don't support finding out
* missing nested fields of struct nested in array or struct nested in map.
*/
private def addFields(col: NamedExpression, target: StructType): Expression = {
private def addFields(col: Expression, expectedFields: Seq[StructField]): Expression = {
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")

val resolver = conf.resolver
val missingFieldsOpt =
StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)

// We need to sort columns in result, because we might add another column in other side.
// E.g., we want to union two structs "a int, b long" and "a int, c string".
// If we don't sort, we will have "a int, b long, c string" and
// "a int, c string, b long", which are not compatible.
if (missingFieldsOpt.isEmpty) {
sortStructFields(col)
} else {
missingFieldsOpt.map { s =>
val struct = addFieldsInto(col, s.fields)
// Combines `WithFields`s to reduce expression tree.
val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields)
val sorted = sortStructFieldsInWithFields(reducedStruct)
sorted
}.get
}
}

/**
* Adds missing fields recursively into given `col` expression. The missing fields are given
* in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and `fields` is
* "z struct<w:long>, w string". This method will add a nested `z.w` field and a top-level
* `w` field to `col` and fill null values for them. Note that because we might also add missing
* fields at the other side of Union, we must make sure corresponding attributes on both sides
* have the same field order in structs, so when adding missing fields, we sort the fields by
* field name. So the data type of the returned expression will be
* "w string, x int, z struct<w:long, y:int, z:int>".
*/
private def addFieldsInto(
col: Expression,
fields: Seq[StructField]): Expression = {
fields.foldLeft(col) { case (currCol, field) =>
field.dataType match {
case st: StructType =>
val resolver = conf.resolver
val colField = currCol.dataType.asInstanceOf[StructType]
.find(f => resolver(f.name, field.name))
if (colField.isEmpty) {
// The whole struct is missing. Add a null.
UpdateFields(currCol, field.name, Literal(null, st))
} else {
UpdateFields(currCol, field.name,
addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields))
}
case dt =>
UpdateFields(currCol, field.name, Literal(null, dt))
val colType = col.dataType.asInstanceOf[StructType]
val newStructFields = expectedFields.flatMap { expectedField =>
val currentField = colType.fields.find(f => resolver(f.name, expectedField.name))

val newExpression = (currentField, expectedField.dataType) match {
case (Some(cf), expectedType: StructType) if cf.dataType.isInstanceOf[StructType] =>
val extractedValue = ExtractValue(col, Literal(cf.name), resolver)
val combinedStruct = addFields(extractedValue, expectedType.fields)
if (extractedValue.nullable) {
If(IsNull(extractedValue),
Literal(null, combinedStruct.dataType),
combinedStruct)
} else {
combinedStruct
}
case (Some(cf), _) =>
ExtractValue(col, Literal(cf.name), resolver)
case (None, expectedType) =>
Literal(null, expectedType)
}
Literal(expectedField.name) :: newExpression :: Nil
}
CreateNamedStruct(newStructFields)
}


/**
* This method will compare right to left plan's outputs. If there is one struct attribute
* at right side has same name with left side struct attribute, but two structs are not the
Expand Down Expand Up @@ -181,7 +100,8 @@ object ResolveUnion extends Rule[LogicalPlan] {
// like that. We will sort columns in the struct expression to make sure two sides of
// union have consistent schema.
aliased += foundAttr
Alias(addFields(foundAttr, target), foundAttr.name)()
val targetType = target.merge(source, conf.resolver)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, merge will throw an exception if two schemas conflict. I recall that union of conflicting schemas doesn't fail in ResolveUnion, but in CheckAnalysis. Could we follow original behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm good question, I'm not sure exactly how that would work without adding extra logic to StructType.merge to ignore conflicts. And now that you bring that up I'm starting to think using StructType.merge isn't the best method since it does care about DataType. I just noticed it doesn't handle similar types, so you get errors if you try to merge a float and a double, whereas the normal union just handles that. I might try to rework this again to not use the StructType.merge after all...

Alias(addFields(foundAttr, targetType.fields.toSeq), foundAttr.name)()
case _ =>
// We don't need/try to add missing fields if:
// 1. The attributes of left and right side are the same struct type
Expand All @@ -208,13 +128,11 @@ object ResolveUnion extends Rule[LogicalPlan] {
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)

// Delegates failure checks to `CheckAnalysis`
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
val notFoundAttrs = right.output.diff(rightProjectList ++ aliased)
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)

// Builds a project for `logicalPlan` based on `right` output names, if allowing
Expand All @@ -230,6 +148,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
} else {
left
}

Union(leftChild, rightChild)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object SchemaPruning extends SQLConfHelper {
// original schema
val mergedSchema = requestedRootFields
.map { root: RootField => StructType(Array(root.field)) }
.reduceLeft(_ merge _)
.reduceLeft((left, right) => left.merge(right, resolver))
val mergedDataSchema =
StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
Expand Down Expand Up @@ -113,7 +113,7 @@ object SchemaPruning extends SQLConfHelper {
// this optional root field too.
val rootFieldType = StructType(Array(root.field))
val optFieldType = StructType(Array(opt.field))
val merged = optFieldType.merge(rootFieldType)
val merged = optFieldType.merge(rootFieldType, conf.resolver)
merged.sameType(optFieldType)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ case class Union(
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
val newDt = attrs.map(_.dataType).reduce(StructType.merge(conf.resolver))
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
* thrown.
*/
private[sql] def merge(that: StructType): StructType =
StructType.merge(this, that).asInstanceOf[StructType]
private[sql] def merge(that: StructType, resolver: Resolver): StructType =
StructType.merge(resolver)(this, that).asInstanceOf[StructType]

override private[spark] def asNullable: StructType = {
val newFields = fields.map {
Expand Down Expand Up @@ -555,32 +555,31 @@ object StructType extends AbstractDataType {
case _ => dt
}

private[sql] def merge(left: DataType, right: DataType): DataType =
private[sql] def merge(resolver: Resolver)(left: DataType, right: DataType): DataType =
(left, right) match {
case (ArrayType(leftElementType, leftContainsNull),
ArrayType(rightElementType, rightContainsNull)) =>
ArrayType(
merge(leftElementType, rightElementType),
merge(resolver)(leftElementType, rightElementType),
leftContainsNull || rightContainsNull)

case (MapType(leftKeyType, leftValueType, leftContainsNull),
MapType(rightKeyType, rightValueType, rightContainsNull)) =>
MapType(
merge(leftKeyType, rightKeyType),
merge(leftValueType, rightValueType),
merge(resolver)(leftKeyType, rightKeyType),
merge(resolver)(leftValueType, rightValueType),
leftContainsNull || rightContainsNull)

case (StructType(leftFields), StructType(rightFields)) =>
val newFields = mutable.ArrayBuffer.empty[StructField]

val rightMapped = fieldsMap(rightFields)
leftFields.foreach {
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
rightMapped.get(leftName)
rightFields.find(f => resolver(leftName, f.name))
.map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
try {
leftField.copy(
dataType = merge(leftType, rightType),
dataType = merge(resolver)(leftType, rightType),
nullable = leftNullable || rightNullable)
} catch {
case NonFatal(e) =>
Expand All @@ -593,12 +592,9 @@ object StructType extends AbstractDataType {
.foreach(newFields += _)
}

val leftMapped = fieldsMap(leftFields)
rightFields
.filterNot(f => leftMapped.get(f.name).nonEmpty)
.foreach { f =>
newFields += f
}
.filter(f => leftFields.find(lf => resolver(f.name, lf.name)).isEmpty)
.foreach(newFields += _)

StructType(newFields.toSeq)

Expand Down Expand Up @@ -634,39 +630,4 @@ object StructType extends AbstractDataType {
fields.foreach(s => map.put(s.name, s))
map
}

/**
 * Collects the fields that `target` has but `source` lacks, matching names with `resolver`.
 * A field present on both sides contributes only when both sides are structs of different
 * types; in that case the nested missing fields are computed recursively and returned under
 * the `source` field. Array and map element types are not inspected.
 *
 * @return `None` when nothing is missing, otherwise a `StructType` holding the missing fields
 *         in `target` field order.
 */
def findMissingFields(
    source: StructType,
    target: StructType,
    resolver: Resolver): Option[StructType] = {
  val missing: Seq[StructField] = target.fields.toSeq.flatMap { wanted =>
    val addition: Option[StructField] =
      source.fields.find(candidate => resolver(wanted.name, candidate.name)) match {
        case None =>
          // `source` has no field with this name at all.
          Some(wanted)
        case Some(existing) =>
          (existing.dataType, wanted.dataType) match {
            case (have: StructType, want: StructType) if !have.sameType(want) =>
              // Same name, both structs, different shape: keep only the nested gaps.
              findMissingFields(have, want, resolver)
                .map(nestedMissing => existing.copy(dataType = nestedMissing))
            case _ =>
              None
          }
      }
    addition
  }

  if (missing.isEmpty) {
    None
  } else {
    Some(StructType(missing))
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.fasterxml.jackson.core.JsonParseException
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes

class DataTypeSuite extends SparkFunSuite {
Expand Down Expand Up @@ -153,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite {
StructField("b", LongType) :: Nil)

val message = intercept[SparkException] {
left.merge(right)
left.merge(right, SQLConf.get.resolver)
}.getMessage
assert(message.equals("Failed to merge fields 'b' and 'b'. " +
"Failed to merge incompatible data types float and bigint"))
Expand Down
Loading