diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fa14aa14ee96..d50206146265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2300,6 +2300,60 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * (Scala-specific) Returns a new Dataset with columns renamed.
+   * This is a no-op if the schema doesn't contain any of the columns in the map.
+   *
+   * {{{
+   *   ds.withColumnsRenamed(
+   *     "exist_column1" -> "new_column1",
+   *     "exist_column2" -> "new_column2"
+   *   )
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.0.0
+   */
+  @scala.annotation.varargs
+  def withColumnsRenamed(columnMap: (String, String), columnMaps: (String, String)*): DataFrame = {
+    withColumnsRenamed((columnMap +: columnMaps).toMap)
+  }
+
+  /**
+   * (Scala-specific) Returns a new Dataset with columns renamed.
+   * This is a no-op if the schema doesn't contain any of the columns in the map.
+   *
+   * {{{
+   *   ds.withColumnsRenamed(Map(
+   *     "exist_column1" -> "new_column1",
+   *     "exist_column2" -> "new_column2"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.0.0
+   */
+  def withColumnsRenamed(columnMap: Map[String, String]): DataFrame = {
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val allColumns = queryExecution.analyzed.output
+
+    val shouldRename = allColumns.exists { attribute =>
+      columnMap.exists(m => resolver(attribute.name, m._1))
+    }
+
+    if (shouldRename) {
+      val newColumns = allColumns.map { attribute =>
+        columnMap
+          .find(m => resolver(attribute.name, m._1))
+          .map(m => Column(attribute).as(m._2))
+          .getOrElse(Column(attribute))
+      }
+      select(newColumns: _*)
+    } else {
+      toDF()
+    }
+  }
+
   /**
    * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain
    * column name.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4e593ff046a5..c1e521ee23aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1547,6 +1547,36 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.where($"city".contains(new java.lang.Character('A'))),
       Seq(Row("Amsterdam")))
   }
+
+  test("SPARK-25571: Add withColumnsRenamed method to Dataset") {
+    val df1 = testData
+      .withColumn("key_squared", col("key") * col("key"))
+      .withColumnsRenamed(
+        "key" -> "newKey",
+        "value" -> "newValue",
+        "key_squared" -> "newKeySquared",
+        "not_exist_column" -> "notExistColumn"
+      )
+
+    val df2 = testData
+      .withColumn("key_squared", col("key") * col("key"))
+      .withColumnsRenamed(Map(
+        "key" -> "newKey",
+        "value" -> "newValue",
+        "key_squared" -> "newKeySquared",
+        "not_exist_column" -> "notExistColumn"
+      ))
+
+    val expectedColumns = Seq("newKey", "newValue", "newKeySquared")
+    val expectedRows = testData.collect().map { case Row(key: Int, value: String) =>
+      Row(key, value, key * key)
+    }.toSeq
+
+    assert(df1.schema.map(_.name) === expectedColumns)
+    assert(df2.schema.map(_.name) === expectedColumns)
+    checkAnswer(df1, expectedRows)
+    checkAnswer(df2, expectedRows)
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
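
For reviewers: a minimal usage sketch of the two overloads added above, assuming a Spark build that includes this patch. The local session setup and the sample DataFrame are illustrative assumptions, not part of the change.

    import org.apache.spark.sql.SparkSession

    object WithColumnsRenamedSketch {
      def main(args: Array[String]): Unit = {
        // Illustrative local session; any build with this patch applied works.
        val spark = SparkSession.builder()
          .appName("withColumnsRenamed sketch")
          .master("local[*]")
          .getOrCreate()
        import spark.implicits._

        val df = Seq((1, "a"), (2, "b")).toDF("key", "value")

        // Varargs overload: one or more (existing name -> new name) pairs.
        val renamed1 = df.withColumnsRenamed("key" -> "id", "value" -> "label")

        // Map overload: names absent from the schema (here "missing") are
        // ignored, per the no-op behavior documented in the Scaladoc above.
        val renamed2 = df.withColumnsRenamed(Map("key" -> "id", "missing" -> "x"))

        renamed1.printSchema() // id, label
        renamed2.printSchema() // id, value

        spark.stop()
      }
    }

Because both overloads resolve names through the analyzer's resolver, the rename respects the session's case-sensitivity setting, and columns not mentioned in the map pass through unchanged.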