14 changes: 7 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
-import org.apache.spark.sql.catalyst.optimizer.CombineUnions
+import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -2146,7 +2146,7 @@ class Dataset[T] private[sql](
* Returns a new Dataset by adding columns or replacing the existing columns that has
* the same names.
*/
-private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = withPlan {

Member:
As stated on the JIRA ticket, the problem is a deep query plan. There are many ways to create such a deep plan, not only withColumns. For example, you can call select many times to do that too. This change makes withColumns a special case.
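
For illustration (not part of the PR), a minimal sketch of how repeated narrow transformations stack Project nodes in the unoptimized logical plan; it assumes an active SparkSession named `spark`, and the column names are made up:

```scala
import org.apache.spark.sql.functions.{col, lit}

// Each select/withColumn call wraps the previous logical plan in one more
// Project node; the nesting only disappears once the optimizer applies rules
// such as CollapseProject.
val base = spark.range(3).toDF("a")
val deep = base
  .select(col("a"), lit(1).as("b"))            // Project over the Range scan
  .select(col("a"), col("b"), lit(2).as("c"))  // Project over the previous Project
  .withColumn("d", lit(3))                     // one more Project on top

// Prints the chain of nested Project nodes built before optimization.
println(deep.queryExecution.logical.numberedTreeString)
```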

Contributor Author:
Yes, but I think this is a special case worth handling. I have seen many cases where withColumn is used in for loops: with this change that pattern would be better supported.
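
For illustration (not part of the PR), the kind of loop meant here, with invented column names and assuming a SparkSession named `spark`:

```scala
import org.apache.spark.sql.functions.lit

// Typical feature-engineering loop: one new column per iteration. Before this
// change, every iteration stacks another Project on the logical plan, so the
// plan depth grows linearly with the number of iterations.
var df = spark.range(100).toDF("id")
for (i <- 1 to 50) {
  df = df.withColumn(s"feature_$i", lit(i))
}
df.explain(true)  // shows the accumulated plan
```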

require(colNames.size == cols.size,
s"The size of column names: ${colNames.size} isn't equal to " +
s"the size of columns: ${cols.size}")
@@ -2164,16 +2164,16 @@
columnMap.find { case (colName, _) =>
resolver(field.name, colName)
} match {
-case Some((colName: String, col: Column)) => col.as(colName)
-case _ => Column(field)
+case Some((colName: String, col: Column)) => col.as(colName).named
+case _ => field
}
}

-val newColumns = columnMap.filter { case (colName, col) =>
+val newColumns = columnMap.filter { case (colName, _) =>
!output.exists(f => resolver(f.name, colName))
-}.map { case (colName, col) => col.as(colName) }
+}.map { case (colName, col) => col.as(colName).named }

-select(replacedAndExistingColumns ++ newColumns : _*)
+CollapseProject(Project(replacedAndExistingColumns ++ newColumns, logicalPlan))

Contributor:
Can we reduce the scope of this optimization? E.g., if the root node of this query is a Project, update its project list to include the withColumns expressions; otherwise add a new Project.
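
For illustration (not part of the PR), an untested sketch of this narrower idea, reusing the names from the method body above; note that merging into an existing top-level Project still needs the alias substitution that CollapseProject performs, which is omitted here:

```scala
import org.apache.spark.sql.catalyst.plans.logical.Project

// Untested sketch: only avoid stacking a new Project when the current root is
// already a Project; nothing below the root is touched.
val newProjectList = replacedAndExistingColumns ++ newColumns
val plan = logicalPlan match {
  case Project(_, child) =>
    // The expressions in newProjectList are written against logicalPlan's
    // output, so they would first have to be rewritten in terms of `child`
    // (the substitution CollapseProject does); that step is omitted here.
    Project(newProjectList, child)
  case other =>
    Project(newProjectList, other)
}
```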

Contributor Author:
I don't think we can do that. Imagine the case where each column depends on the previously added one: if we did that, we would end up with an invalid plan. Or am I missing something?
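
For illustration (not part of the PR), the dependency chain meant here, with invented column names and assuming a SparkSession named `spark`:

```scala
import org.apache.spark.sql.functions.col

// Each new column references the column added just before it. Simply appending
// `col("a") + 1 AS b` to the Project that defines `a` would not resolve,
// because a Project's expressions are evaluated against its child's output,
// not against sibling aliases in the same list. When CollapseProject merges
// the adjacent Projects, it substitutes the defining expression (x + 1) instead.
val df = spark.range(1).toDF("x")
  .withColumn("a", col("x") + 1)
  .withColumn("b", col("a") + 1)
  .withColumn("c", col("b") + 1)
df.show()
```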

}

/**
13 changes: 13 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -1656,6 +1657,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
}

test("SPARK-26224: withColumn produces too many Projects") {
val N = 10
val resDF = (1 to N).foldLeft(Seq(1).toDF("a")) { case (df, i) =>
df.withColumn(s"col$i", lit(0))
}
assert(resDF.queryExecution.logical.collect {
case _: Project => true
}.size == 1)
val result = Row(1 :: List.fill(N)(0): _*)
checkAnswer(resDF, result)
}
}

case class TestDataUnion(x: Int, y: Int, z: Int)