diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6f97121d88ede..3472b9fdec9d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2030,7 +2030,47 @@ class Dataset[T] private[sql](
    * @group typedrel
    * @since 2.3.0
    */
-  def unionByName(other: Dataset[T]): Dataset[T] = withSetOperator {
+  def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, false)
+
+  /**
+   * Returns a new Dataset containing the union of rows in this Dataset and another Dataset.
+   *
+   * The difference between this function and [[union]] is that this function
+   * resolves columns by name (not by position).
+   *
+   * When the parameter `allowMissingColumns` is true, this function allows a different set
+   * of column names between the two Datasets. Missing columns on either side will be filled
+   * with null values. The columns missing from the left Dataset are added at the end of the
+   * schema of the union result:
+   *
+   * {{{
+   *   val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
+   *   val df2 = Seq((4, 5, 6)).toDF("col1", "col0", "col3")
+   *   df1.unionByName(df2, true).show
+   *
+   *   // output: "col3" is missing from the left df1 and is added at the end of the schema.
+   *   // +----+----+----+----+
+   *   // |col0|col1|col2|col3|
+   *   // +----+----+----+----+
+   *   // |   1|   2|   3|null|
+   *   // |   5|   4|null|   6|
+   *   // +----+----+----+----+
+   *
+   *   df2.unionByName(df1, true).show
+   *
+   *   // output: "col2" is missing from the left df2 and is added at the end of the schema.
+   *   // +----+----+----+----+
+   *   // |col1|col0|col3|col2|
+   *   // +----+----+----+----+
+   *   // |   4|   5|   6|null|
+   *   // |   2|   1|null|   3|
+   *   // +----+----+----+----+
+   * }}}
+   *
+   * @group typedrel
+   * @since 3.1.0
+   */
+  def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator {
     // Check column name duplication
     val resolver = sparkSession.sessionState.analyzer.resolver
     val leftOutputAttrs = logicalPlan.output
@@ -2048,9 +2088,13 @@ class Dataset[T] private[sql](
     // Builds a project list for `other` based on `logicalPlan` output names
     val rightProjectList = leftOutputAttrs.map { lattr =>
       rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
-        throw new AnalysisException(
-          s"""Cannot resolve column name "${lattr.name}" among """ +
-            s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""")
+        if (allowMissingColumns) {
+          Alias(Literal(null, lattr.dataType), lattr.name)()
+        } else {
+          throw new AnalysisException(
+            s"""Cannot resolve column name "${lattr.name}" among """ +
+              s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""")
+        }
       }
     }
 
@@ -2058,9 +2102,20 @@ class Dataset[T] private[sql](
     val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
     val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan)
 
+    // Builds a project for `logicalPlan` based on `other` output names, if allowing
+    // missing columns.
+    val leftChild = if (allowMissingColumns) {
+      val missingAttrs = notFoundAttrs.map { attr =>
+        Alias(Literal(null, attr.dataType), attr.name)()
+      }
+      Project(leftOutputAttrs ++ missingAttrs, logicalPlan)
+    } else {
+      logicalPlan
+    }
+
     // This breaks caching, but it's usually ok because it addresses a very specific use case:
     // using union to union many files or partitions.
-    CombineUnions(Union(logicalPlan, rightChild))
+    CombineUnions(Union(leftChild, rightChild))
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index bd3f48078374d..11d7907c5a193 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -506,4 +506,35 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
     check(lit(2).cast("int"), $"c" === 2, Seq(Row(1, 1, 2, 0), Row(1, 1, 2, 1)))
     check(lit(2).cast("int"), $"c" =!= 2, Seq())
   }
+
+  test("SPARK-29358: Make unionByName optionally fill missing columns with nulls") {
+    var df1 = Seq(1, 2, 3).toDF("a")
+    var df2 = Seq(3, 1, 2).toDF("b")
+    val df3 = Seq(2, 3, 1).toDF("c")
+    val unionDf = df1.unionByName(df2.unionByName(df3, true), true)
+    checkAnswer(unionDf,
+      Row(1, null, null) :: Row(2, null, null) :: Row(3, null, null) :: // df1
+      Row(null, 3, null) :: Row(null, 1, null) :: Row(null, 2, null) :: // df2
+      Row(null, null, 2) :: Row(null, null, 3) :: Row(null, null, 1) :: Nil // df3
+    )
+
+    df1 = Seq((1, 2)).toDF("a", "c")
+    df2 = Seq((3, 4, 5)).toDF("a", "b", "c")
+    checkAnswer(df1.unionByName(df2, true),
+      Row(1, 2, null) :: Row(3, 5, 4) :: Nil)
+    checkAnswer(df2.unionByName(df1, true),
+      Row(3, 4, 5) :: Row(1, null, 2) :: Nil)
+
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      df2 = Seq((3, 4, 5)).toDF("a", "B", "C")
+      val union1 = df1.unionByName(df2, true)
+      val union2 = df2.unionByName(df1, true)
+
+      checkAnswer(union1, Row(1, 2, null, null) :: Row(3, null, 4, 5) :: Nil)
+      checkAnswer(union2, Row(3, 4, 5, null) :: Row(1, null, null, 2) :: Nil)
+
+      assert(union1.schema.fieldNames === Array("a", "c", "B", "C"))
+      assert(union2.schema.fieldNames === Array("a", "B", "C", "c"))
+    }
+  }
 }
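For reviewers, below is a minimal usage sketch of the new overload. It assumes a build with this patch applied and a local `SparkSession`; the object and app names are illustrative, and the data and expected output are taken from the Scaladoc example above.

```scala
// Usage sketch (not part of the patch): demonstrates the new
// unionByName(other, allowMissingColumns) overload on a local SparkSession.
import org.apache.spark.sql.SparkSession

object UnionByNameExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("unionByName-allowMissingColumns")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Same data as the Scaladoc example in Dataset.scala above.
    val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
    val df2 = Seq((4, 5, 6)).toDF("col1", "col0", "col3")

    // Columns are resolved by name, not position; "col3" is absent from df1,
    // so it is appended to the result schema and null-filled for df1's rows.
    df1.unionByName(df2, allowMissingColumns = true).show()
    // +----+----+----+----+
    // |col0|col1|col2|col3|
    // +----+----+----+----+
    // |   1|   2|   3|null|
    // |   5|   4|null|   6|
    // +----+----+----+----+

    // Without the flag, the old behavior is preserved and the same call
    // throws an AnalysisException, since "col2" cannot be resolved in df2:
    // df1.unionByName(df2)

    spark.stop()
  }
}
```

Note that the null-filled columns come from the `Project` nodes the patch adds on each side, so the `Union` itself still operates on two children with schemas of equal width.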