diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fa14aa14ee96..d3a170896a14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2300,6 +2300,38 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Returns a new Dataset with columns renamed. This is a no-op for every
+   * key in `columnMap` that does not resolve to an existing column name.
+   * {{{
+   *   df.withColumnRenamed(Map(
+   *     "c1" -> "first_column",
+   *     "c2" -> "second_column"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.0.0
+   */
+  def withColumnRenamed(columnMap: Map[String, String]): DataFrame = {
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val output = queryExecution.analyzed.output
+    val existingNames = columnMap.keys.toSeq
+    // Match through the analyzer's resolver (like the single-column overload)
+    // so renaming honors spark.sql.caseSensitive instead of exact equality.
+    val shouldRename = output.exists(attr => existingNames.exists(resolver(attr.name, _)))
+    if (shouldRename) {
+      val columns = output.map { col =>
+        columnMap.find { case (existingName, _) => resolver(col.name, existingName) } match {
+          case Some((_, newName)) => Column(col).as(newName)
+          case _ => Column(col)
+        }
+      }
+      select(columns: _*)
+    } else {
+      toDF()
+    }
+  }
+
   /**
    * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain
    * column name.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f001b138f4b8..525f139e74a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1021,6 +1021,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-25430: Add map parameter for withColumnRenamed") {
+    // "newCol2" is absent from the schema and must be ignored (no-op entry).
+    val df = testData.toDF().withColumn("newCol", col("key") + 1)
+      .withColumnRenamed(Map("value" -> "valueRenamed", "newCol" -> "newColRenamed",
+        "newCol2" -> "newColRenamed2"))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newColRenamed"))
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),