@@ -249,6 +249,7 @@ class Analyzer(
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
ResolveBinaryArithmetic ::
ResolveUnion ::
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -1387,10 +1388,11 @@ class Analyzer(
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
e.copy(right = dedupRight(left, right))
case u @ Union(children) if !u.duplicateResolved =>
// Only after we finish by-name resolution for Union
case u: Union if !u.byName && !u.duplicateResolved =>
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing
// feature in streaming.
val newChildren = children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) =>
val newChildren = u.children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) =>
head +: tail.map {
case child if head.outputSet.intersect(child.outputSet).isEmpty =>
child
@@ -3398,7 +3400,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] {
*/
object EliminateUnions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case Union(children) if children.size == 1 => children.head
case u: Union if u.children.size == 1 => u.children.head
}
}

@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.SchemaUtils

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {
private def unionTwoSides(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
val resolver = SQLConf.get.resolver
val leftOutputAttrs = left.output
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val rightProjectList = leftOutputAttrs.map { lattr =>
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
if (allowMissingCol) {
Alias(Literal(null, lattr.dataType), lattr.name)()
} else {
throw new AnalysisException(
s"""Cannot resolve column name "${lattr.name}" among """ +
Member:

How about making it consistent (using ` for wrapping a column name)?

          throw new AnalysisException(
            s"Cannot resolve column name `${lattr.name}` among " +
              s"(${rightOutputAttrs.map(_.name).mkString(", ")})")

https://github.com/apache/spark/pull/29107/files#diff-1d14ac233eac6f233c027dba0bdf871dR341

Member Author:

ok.

s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""")
}
}
}

// Delegates failure checks to `CheckAnalysis`
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)

// Builds a project for `left` based on `right` output names, if allowing
// missing columns.
val leftChild = if (allowMissingCol) {
val missingAttrs = notFoundAttrs.map { attr =>
Alias(Literal(null, attr.dataType), attr.name)()
}
if (missingAttrs.nonEmpty) {
Project(leftOutputAttrs ++ missingAttrs, left)
} else {
left
}
} else {
left
}
Union(leftChild, rightChild)
}

// Check column name duplication
private def checkColumnNames(left: LogicalPlan, right: LogicalPlan): Unit = {
val caseSensitiveAnalysis = SQLConf.get.caseSensitiveAnalysis
val leftOutputAttrs = left.output
val rightOutputAttrs = right.output

SchemaUtils.checkColumnNameDuplication(
leftOutputAttrs.map(_.name),
"in the left attributes",
caseSensitiveAnalysis)
SchemaUtils.checkColumnNameDuplication(
rightOutputAttrs.map(_.name),
"in the right attributes",
caseSensitiveAnalysis)
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case e if !e.childrenResolved => e

case Union(children, byName, allowMissingCol) if byName =>
val union = children.reduceLeft { (left, right) =>
checkColumnNames(left, right)
unionTwoSides(left, right, allowMissingCol)
}
CombineUnions(union)
}
}
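
A minimal sketch of what this rule produces, using two hypothetical relations whose columns match by name but not by position (the `ResolveUnionSuite` added below exercises the same shapes): the right child is wrapped in a `Project` that reorders its output to the left child's column order, and with `allowMissingCol` enabled, columns absent on one side would be filled with null literals.

// Sketch only: the plan shape ResolveUnion produces for a by-name union.
// The relations here are hypothetical; see ResolveUnionSuite below for the real tests.
import org.apache.spark.sql.catalyst.analysis.ResolveUnion
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Union}
import org.apache.spark.sql.types.{IntegerType, StringType}

val left = LocalRelation(
  AttributeReference("i", IntegerType)(), AttributeReference("s", StringType)())
val right = LocalRelation(
  AttributeReference("s", StringType)(), AttributeReference("i", IntegerType)())

// By-name resolution re-projects the right child into (i, s) order, i.e. roughly
//   Union(left, Project(Seq(right.output(1), right.output(0)), right))
val resolved = ResolveUnion(Union(left :: right :: Nil, byName = true, allowMissingCol = false))
println(resolved.treeString)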
@@ -341,10 +341,10 @@ object TypeCoercion {
assert(newChildren.length == 2)
Intersect(newChildren.head, newChildren.last, isAll)

case s: Union if s.childrenResolved &&
case s: Union if s.childrenResolved && !s.byName &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
s.makeCopy(Array(newChildren))
s.copy(children = newChildren)
}

/** Build new children with the widest types for each attribute among all the children */
@@ -497,8 +497,8 @@ object LimitPushDown extends Rule[LogicalPlan] {
// Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to
// pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to
// pushdown Limit.
case LocalLimit(exp, Union(children)) =>
LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _))))
case LocalLimit(exp, u: Union) =>
LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))
// Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to
// the left and right sides, respectively. It's not safe to push limits below FULL OUTER
// JOIN in the general case without a more invasive rewrite.
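
A rough illustration of the Union branch above, with hypothetical relations `a` and `b`: the local limit is copied into each child of the UNION ALL while the outer limit stays in place.

// Sketch: LocalLimit(5, Union(a, b)) becomes, roughly,
//   LocalLimit(5, Union(LocalLimit(5, a), LocalLimit(5, b)))
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.optimizer.LimitPushDown
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation, Union}
import org.apache.spark.sql.types.IntegerType

val a = LocalRelation(AttributeReference("x", IntegerType)())
val b = LocalRelation(AttributeReference("x", IntegerType)())
val pushed = LimitPushDown(LocalLimit(Literal(5), Union(a :: b :: Nil)))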
@@ -556,15 +556,15 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
def apply(plan: LogicalPlan): LogicalPlan = plan transform {

// Push down deterministic projection through UNION ALL
case p @ Project(projectList, Union(children)) =>
assert(children.nonEmpty)
case p @ Project(projectList, u: Union) =>
assert(u.children.nonEmpty)
if (projectList.forall(_.deterministic)) {
val newFirstChild = Project(projectList, children.head)
val newOtherChildren = children.tail.map { child =>
val rewrites = buildRewrites(children.head, child)
val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
}
Union(newFirstChild +: newOtherChildren)
u.copy(children = newFirstChild +: newOtherChildren)
} else {
p
}
@@ -928,19 +928,28 @@ object CombineUnions extends Rule[LogicalPlan] {
}

private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = {
val topByName = union.byName
val topAllowMissingCol = union.allowMissingCol

val stack = mutable.Stack[LogicalPlan](union)
val flattened = mutable.ArrayBuffer.empty[LogicalPlan]
// Note that we should only flatten unions with the same byName and allowMissingCol.
// Although we do `UnionCoercion` at analysis phase, we manually run `CombineUnions`
// in some places like `Dataset.union`. Flattening unions with different resolution
// rules (by position and by name) could cause incorrect results.
while (stack.nonEmpty) {
stack.pop() match {
case Distinct(Union(children)) if flattenDistinct =>
case Distinct(Union(children, byName, allowMissingCol))
if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
case Union(children) =>
case Union(children, byName, allowMissingCol)
if byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
case child =>
flattened += child
}
}
Union(flattened.toSeq)
union.copy(children = flattened)
}
}
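
A minimal sketch of the flattening guard, with hypothetical relations r1 and r2 (the `SetOperationSuite` test added below covers the same behaviour): nested unions are only inlined when their flags match the top-level union's.

// Sketch: only children whose byName/allowMissingCol match the top union are flattened.
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Union}
import org.apache.spark.sql.types.IntegerType

val r1 = LocalRelation(AttributeReference("a", IntegerType)())
val r2 = LocalRelation(AttributeReference("a", IntegerType)())

val byNameUnion = Union(r1 :: r1 :: Nil, byName = true, allowMissingCol = false)
val positional  = Union(r1 :: r2 :: Nil)  // byName = false, same flags as the top union

// CombineUnions keeps byNameUnion intact but inlines positional's children:
//   Union(byNameUnion :: r1 :: r2 :: Nil)
val combined = CombineUnions(Union(byNameUnion :: positional :: Nil))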

@@ -50,8 +50,8 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit
override def conf: SQLConf = SQLConf.get

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p @ Union(children) if children.exists(isEmptyLocalRelation) =>
val newChildren = children.filterNot(isEmptyLocalRelation)
case p: Union if p.children.exists(isEmptyLocalRelation) =>
val newChildren = p.children.filterNot(isEmptyLocalRelation)
if (newChildren.isEmpty) {
empty(p)
} else {
@@ -220,8 +220,18 @@ object Union {

/**
* Logical plan for unioning two plans, without a distinct. This is UNION ALL in SQL.
*
* @param byName Whether to resolve columns in the children by column names.
* @param allowMissingCol Allows missing columns in children query plans. If it is true,
* this function allows a different set of column names between two Datasets.
* This can be set to true only if `byName` is true.
*/
case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
case class Union(
children: Seq[LogicalPlan],
byName: Boolean = false,
allowMissingCol: Boolean = false) extends LogicalPlan {
assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.")

override def maxRows: Option[Long] = {
if (children.exists(_.maxRows.isEmpty)) {
None
@@ -271,7 +281,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
child.output.zip(children.head.output).forall {
case (l, r) => l.dataType.sameType(r.dataType)
})
children.length > 1 && childrenResolved && allChildrenCompatible
children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
}

/**
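
For context, these flags presumably back the by-name variants of the Dataset API. A hedged end-user sketch (assuming an `allowMissingColumns` parameter on `Dataset.unionByName`, which is what this change appears to enable) would look like:

// Hypothetical end-user view of the two flags; assumes unionByName gains an
// allowMissingColumns parameter as part of this change.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val df1 = Seq((1, "a")).toDF("id", "name")
val df2 = Seq(("b", 2, 3.0)).toDF("name", "id", "score")

// byName = true, allowMissingCol = false: match columns by name, order-independent.
df1.unionByName(df2.drop("score")).show()

// byName = true, allowMissingCol = true: missing columns are null-filled on each side.
df1.unionByName(df2, allowMissingColumns = true).show()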
@@ -73,7 +73,7 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
val (l, r) = analyzer.execute(plan).collect {
case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head)
case Union(Seq(child1, child2), _, _) => (child1.output.head, child2.output.head)
}.head
assert(l.dataType === expectedType)
assert(r.dataType === expectedType)
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class ResolveUnionSuite extends AnalysisTest {
test("Resolve Union") {
val table1 = LocalRelation(
AttributeReference("i", IntegerType)(),
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("b", ByteType)(),
AttributeReference("d", DoubleType)())
val table2 = LocalRelation(
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("b", ByteType)(),
AttributeReference("d", DoubleType)(),
AttributeReference("i", IntegerType)())
val table3 = LocalRelation(
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("d", DoubleType)(),
AttributeReference("i", IntegerType)())
val table4 = LocalRelation(
AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(),
AttributeReference("i", IntegerType)())

val rules = Seq(ResolveUnion)
val analyzer = new RuleExecutor[LogicalPlan] {
override val batches = Seq(Batch("Resolution", Once, rules: _*))
}

// By name resolution
val union1 = Union(table1 :: table2 :: Nil, true, false)
val analyzed1 = analyzer.execute(union1)
val projected1 =
Project(Seq(table2.output(3), table2.output(0), table2.output(1), table2.output(2)), table2)
val expected1 = Union(table1 :: projected1 :: Nil)
comparePlans(analyzed1, expected1)

// Allow missing column
val union2 = Union(table1 :: table3 :: Nil, true, true)
val analyzed2 = analyzer.execute(union2)
val nullAttr1 = Alias(Literal(null, ByteType), "b")()
val projected2 =
Project(Seq(table2.output(3), table2.output(0), nullAttr1, table2.output(2)), table3)
val expected2 = Union(table1 :: projected2 :: Nil)
comparePlans(analyzed2, expected2)

// Allow missing column + Allow missing column
val union3 = Union(union2 :: table4 :: Nil, true, true)
val analyzed3 = analyzer.execute(union3)
val nullAttr2 = Alias(Literal(null, DoubleType), "d")()
val projected3 =
Project(Seq(table2.output(3), table2.output(0), nullAttr1, nullAttr2), table4)
val expected3 = Union(table1 :: projected2 :: projected3 :: Nil)
comparePlans(analyzed3, expected3)
}
}
@@ -223,4 +223,21 @@ class SetOperationSuite extends PlanTest {
val unionCorrectAnswer = unionQuery.analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}

test("CombineUnions only flatten the unions with same byName and allowMissingCol") {
val union1 = Union(testRelation :: testRelation :: Nil, true, false)
val union2 = Union(testRelation :: testRelation :: Nil, true, true)
val union3 = Union(testRelation :: testRelation2 :: Nil, false, false)

val union4 = Union(union1 :: union2 :: union3 :: Nil)
val unionOptimized1 = Optimize.execute(union4)
val unionCorrectAnswer1 = Union(union1 :: union2 :: testRelation :: testRelation2 :: Nil)
comparePlans(unionOptimized1, unionCorrectAnswer1, false)

val union5 = Union(union1 :: union1 :: Nil, true, false)
val unionOptimized2 = Optimize.execute(union5)
val unionCorrectAnswer2 =
Union(testRelation :: testRelation :: testRelation :: testRelation :: Nil, true, false)
comparePlans(unionOptimized2, unionCorrectAnswer2, false)
}
}
@@ -583,7 +583,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
JObject(
"class" -> classOf[Union].getName,
"num-children" -> 2,
"children" -> List(0, 1)),
"children" -> List(0, 1),
"byName" -> JBool(false),
"allowMissingCol" -> JBool(false)),
JObject(
"class" -> classOf[JsonTestTreeNode].getName,
"num-children" -> 0,