diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 023ef2ee17473..ed6fd93de926a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -249,6 +249,7 @@ class Analyzer( ResolveTimeZone(conf) :: ResolveRandomSeed :: ResolveBinaryArithmetic :: + ResolveUnion :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -1387,10 +1388,11 @@ class Analyzer( i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => e.copy(right = dedupRight(left, right)) - case u @ Union(children) if !u.duplicateResolved => + // Only after we finish by-name resolution for Union + case u: Union if !u.byName && !u.duplicateResolved => // Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing // feature in streaming. 
- val newChildren = children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) => + val newChildren = u.children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) => head +: tail.map { case child if head.outputSet.intersect(child.outputSet).isEmpty => child @@ -3398,7 +3400,7 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { */ object EliminateUnions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case Union(children) if children.size == 1 => children.head + case u: Union if u.children.size == 1 => u.children.head } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala new file mode 100644 index 0000000000000..693a5a4e75443 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.optimizer.CombineUnions +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.SchemaUtils + +/** + * Resolves different children of Union to a common set of columns. + */ +object ResolveUnion extends Rule[LogicalPlan] { + private def unionTwoSides( + left: LogicalPlan, + right: LogicalPlan, + allowMissingCol: Boolean): LogicalPlan = { + val resolver = SQLConf.get.resolver + val leftOutputAttrs = left.output + val rightOutputAttrs = right.output + + // Builds a project list for `right` based on `left` output names + val rightProjectList = leftOutputAttrs.map { lattr => + rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { + if (allowMissingCol) { + Alias(Literal(null, lattr.dataType), lattr.name)() + } else { + throw new AnalysisException( + s"""Cannot resolve column name "${lattr.name}" among """ + + s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") + } + } + } + + // Delegates failure checks to `CheckAnalysis` + val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) + val rightChild = Project(rightProjectList ++ notFoundAttrs, right) + + // Builds a project for `left` based on `right` output names, if allowing + // missing columns. 
+ val leftChild = if (allowMissingCol) { + val missingAttrs = notFoundAttrs.map { attr => + Alias(Literal(null, attr.dataType), attr.name)() + } + if (missingAttrs.nonEmpty) { + Project(leftOutputAttrs ++ missingAttrs, left) + } else { + left + } + } else { + left + } + Union(leftChild, rightChild) + } + + // Check column name duplication + private def checkColumnNames(left: LogicalPlan, right: LogicalPlan): Unit = { + val caseSensitiveAnalysis = SQLConf.get.caseSensitiveAnalysis + val leftOutputAttrs = left.output + val rightOutputAttrs = right.output + + SchemaUtils.checkColumnNameDuplication( + leftOutputAttrs.map(_.name), + "in the left attributes", + caseSensitiveAnalysis) + SchemaUtils.checkColumnNameDuplication( + rightOutputAttrs.map(_.name), + "in the right attributes", + caseSensitiveAnalysis) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case e if !e.childrenResolved => e + + case Union(children, byName, allowMissingCol) if byName => + val union = children.reduceLeft { (left, right) => + checkColumnNames(left, right) + unionTwoSides(left, right, allowMissingCol) + } + CombineUnions(union) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3484108a5503f..604a082be4e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -341,10 +341,10 @@ object TypeCoercion { assert(newChildren.length == 2) Intersect(newChildren.head, newChildren.last, isAll) - case s: Union if s.childrenResolved && + case s: Union if s.childrenResolved && !s.byName && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) - s.makeCopy(Array(newChildren)) + 
s.copy(children = newChildren) } /** Build new children with the widest types for each attribute among all the children */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e481cdbd5fdf4..c4a3f85bbf54b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -497,8 +497,8 @@ object LimitPushDown extends Rule[LogicalPlan] { // Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to // pushdown Limit. - case LocalLimit(exp, Union(children)) => - LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) + case LocalLimit(exp, u: Union) => + LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _)))) // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to // the left and right sides, respectively. It's not safe to push limits below FULL OUTER // JOIN in the general case without a more invasive rewrite. 
@@ -556,15 +556,15 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down deterministic projection through UNION ALL - case p @ Project(projectList, Union(children)) => - assert(children.nonEmpty) + case p @ Project(projectList, u: Union) => + assert(u.children.nonEmpty) if (projectList.forall(_.deterministic)) { - val newFirstChild = Project(projectList, children.head) - val newOtherChildren = children.tail.map { child => - val rewrites = buildRewrites(children.head, child) + val newFirstChild = Project(projectList, u.children.head) + val newOtherChildren = u.children.tail.map { child => + val rewrites = buildRewrites(u.children.head, child) Project(projectList.map(pushToRight(_, rewrites)), child) } - Union(newFirstChild +: newOtherChildren) + u.copy(children = newFirstChild +: newOtherChildren) } else { p } @@ -928,19 +928,28 @@ object CombineUnions extends Rule[LogicalPlan] { } private def flattenUnion(union: Union, flattenDistinct: Boolean): Union = { + val topByName = union.byName + val topAllowMissingCol = union.allowMissingCol + val stack = mutable.Stack[LogicalPlan](union) val flattened = mutable.ArrayBuffer.empty[LogicalPlan] + // Note that we should only flatten the unions with same byName and allowMissingCol. + // Although we do `UnionCoercion` at analysis phase, we manually run `CombineUnions` + // in some places like `Dataset.union`. Flattening unions with different resolution + // rules (by position and by name) could cause incorrect results. 
+ while (stack.nonEmpty) { stack.pop() match { - case Distinct(Union(children)) if flattenDistinct => + case Distinct(Union(children, byName, allowMissingCol)) + if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol => stack.pushAll(children.reverse) - case Union(children) => + case Union(children, byName, allowMissingCol) + if byName == topByName && allowMissingCol == topAllowMissingCol => stack.pushAll(children.reverse) case child => flattened += child } } - Union(flattened.toSeq) + union.copy(children = flattened.toSeq) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 0299646150ff3..d3cdd71eafdb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -50,8 +50,8 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p @ Union(children) if children.exists(isEmptyLocalRelation) => - val newChildren = children.filterNot(isEmptyLocalRelation) + case p: Union if p.children.exists(isEmptyLocalRelation) => + val newChildren = p.children.filterNot(isEmptyLocalRelation) if (newChildren.isEmpty) { empty(p) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 039fd9382000a..54ec5859823c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -220,8 
+220,18 @@ object Union { /** * Logical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + * + * @param byName Whether to resolve columns in the children by column names. + * @param allowMissingCol Allows missing columns in children query plans. If it is true, + * the children may have different sets of column names. + * This can be set to true only if `byName` is true. */ -case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { +case class Union( + children: Seq[LogicalPlan], + byName: Boolean = false, + allowMissingCol: Boolean = false) extends LogicalPlan { + assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.") + override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { None @@ -271,7 +281,7 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { child.output.zip(children.head.output).forall { case (l, r) => l.dataType.sameType(r.dataType) }) - children.length > 1 && childrenResolved && allChildrenCompatible + children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index c3e18c7f9557f..d5991ff10ce6c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -73,7 +73,7 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Union(Seq(child1, child2)) => (child1.output.head, child2.output.head) + case Union(Seq(child1, child2), _, _) => (child1.output.head, 
child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala new file mode 100644 index 0000000000000..5c7ad0067a456 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class ResolveUnionSuite extends AnalysisTest { + test("Resolve Union") { + val table1 = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)()) + val table2 = LocalRelation( + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("b", ByteType)(), + AttributeReference("d", DoubleType)(), + AttributeReference("i", IntegerType)()) + val table3 = LocalRelation( + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("d", DoubleType)(), + AttributeReference("i", IntegerType)()) + val table4 = LocalRelation( + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("i", IntegerType)()) + + val rules = Seq(ResolveUnion) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", Once, rules: _*)) + } + + // By name resolution + val union1 = Union(table1 :: table2 :: Nil, true, false) + val analyzed1 = analyzer.execute(union1) + val projected1 = + Project(Seq(table2.output(3), table2.output(0), table2.output(1), table2.output(2)), table2) + val expected1 = Union(table1 :: projected1 :: Nil) + comparePlans(analyzed1, expected1) + + // Allow missing column + val union2 = Union(table1 :: table3 :: Nil, true, true) + val analyzed2 = analyzer.execute(union2) + val nullAttr1 = Alias(Literal(null, ByteType), "b")() + val projected2 = + Project(Seq(table2.output(3), table2.output(0), nullAttr1, table2.output(2)), table3) + val expected2 = Union(table1 :: projected2 :: Nil) + comparePlans(analyzed2, expected2) + + // Allow missing column + Allow missing column + val union3 = Union(union2 :: 
table4 :: Nil, true, true) + val analyzed3 = analyzer.execute(union3) + val nullAttr2 = Alias(Literal(null, DoubleType), "d")() + val projected3 = + Project(Seq(table2.output(3), table2.output(0), nullAttr1, nullAttr2), table4) + val expected3 = Union(table1 :: projected2 :: projected3 :: Nil) + comparePlans(analyzed3, expected3) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index ccc30b1d2f8ce..2eea840e21a31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -223,4 +223,21 @@ class SetOperationSuite extends PlanTest { val unionCorrectAnswer = unionQuery.analyze comparePlans(unionOptimized, unionCorrectAnswer) } + + test("CombineUnions only flatten the unions with same byName and allowMissingCol") { + val union1 = Union(testRelation :: testRelation :: Nil, true, false) + val union2 = Union(testRelation :: testRelation :: Nil, true, true) + val union3 = Union(testRelation :: testRelation2 :: Nil, false, false) + + val union4 = Union(union1 :: union2 :: union3 :: Nil) + val unionOptimized1 = Optimize.execute(union4) + val unionCorrectAnswer1 = Union(union1 :: union2 :: testRelation :: testRelation2 :: Nil) + comparePlans(unionOptimized1, unionCorrectAnswer1, false) + + val union5 = Union(union1 :: union1 :: Nil, true, false) + val unionOptimized2 = Optimize.execute(union5) + val unionCorrectAnswer2 = + Union(testRelation :: testRelation :: testRelation :: testRelation :: Nil, true, false) + comparePlans(unionOptimized2, unionCorrectAnswer2, false) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 
f5259706325eb..ff51bc0071c80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -583,7 +583,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { JObject( "class" -> classOf[Union].getName, "num-children" -> 2, - "children" -> List(0, 1)), + "children" -> List(0, 1), + "byName" -> JBool(false), + "allowMissingCol" -> JBool(false)), JObject( "class" -> classOf[JsonTestTreeNode].getName, "num-children" -> 0, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3472b9fdec9d8..bf124db38edf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -227,7 +227,7 @@ class Dataset[T] private[sql]( val plan = queryExecution.analyzed match { case c: Command => LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) - case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => + case u @ Union(children, _, _) if children.forall(_.isInstanceOf[Command]) => LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect())) case _ => queryExecution.analyzed @@ -2071,51 +2071,9 @@ class Dataset[T] private[sql]( * @since 3.1.0 */ def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator { - // Check column name duplication - val resolver = sparkSession.sessionState.analyzer.resolver - val leftOutputAttrs = logicalPlan.output - val rightOutputAttrs = other.logicalPlan.output - - SchemaUtils.checkColumnNameDuplication( - leftOutputAttrs.map(_.name), - "in the left attributes", - sparkSession.sessionState.conf.caseSensitiveAnalysis) - SchemaUtils.checkColumnNameDuplication( - rightOutputAttrs.map(_.name), - "in the right attributes", - 
sparkSession.sessionState.conf.caseSensitiveAnalysis) - - // Builds a project list for `other` based on `logicalPlan` output names - val rightProjectList = leftOutputAttrs.map { lattr => - rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { - if (allowMissingColumns) { - Alias(Literal(null, lattr.dataType), lattr.name)() - } else { - throw new AnalysisException( - s"""Cannot resolve column name "${lattr.name}" among """ + - s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""") - } - } - } - - // Delegates failure checks to `CheckAnalysis` - val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) - val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan) - - // Builds a project for `logicalPlan` based on `other` output names, if allowing - // missing columns. - val leftChild = if (allowMissingColumns) { - val missingAttrs = notFoundAttrs.map { attr => - Alias(Literal(null, attr.dataType), attr.name)() - } - Project(leftOutputAttrs ++ missingAttrs, logicalPlan) - } else { - logicalPlan - } - // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. 
- CombineUnions(Union(leftChild, rightChild)) + CombineUnions(Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7b5d8f15962d0..78aa258387daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -683,8 +683,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.LocalLimitExec(limit, planLater(child)) :: Nil case logical.GlobalLimit(IntegerLiteral(limit), child) => execution.GlobalLimitExec(limit, planLater(child)) :: Nil - case logical.Union(unionChildren) => - execution.UnionExec(unionChildren.map(planLater)) :: Nil + case union: logical.Union => + execution.UnionExec(union.children.map(planLater)) :: Nil case g @ logical.Generate(generator, _, outer, _, _, child) => execution.GenerateExec( generator, g.requiredChildOutput, outer, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala index 9107f8afa83d7..b4cb7e3fce3cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -37,9 +37,9 @@ class SparkPlannerSuite extends SharedSparkSession { case ReturnAnswer(child) => planned += 1 planLater(child) :: planLater(NeverPlanned) :: Nil - case Union(children) => + case u: Union => planned += 1 - UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil + UnionExec(u.children.map(planLater)) :: planLater(NeverPlanned) :: Nil case LocalRelation(output, data, _) => planned += 1 LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil