diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 8e2f9cbccdd2..396cb26259f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -609,7 +609,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           partitioningAsV2,
           df.queryExecution.analyzed,
           tableSpec,
-          writeOptions = Map.empty,
+          writeOptions = extraOptions.toMap,
           orCreate = true) // Create the table if it doesn't exist
 
       case (other, _) =>
@@ -631,7 +631,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           partitioningAsV2,
           df.queryExecution.analyzed,
           tableSpec,
-          Map.empty,
+          writeOptions = extraOptions.toMap,
           other == SaveMode.Ignore)
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 951d787571e2..dd810a70d158 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -21,7 +21,7 @@ import java.util.Collections
 
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -207,4 +207,50 @@ class DataSourceV2DataFrameSuite
       assert(options.get(optionName) === "false")
     }
   }
+
+  test("CTAS and RTAS should take write options") {
+
+    var plan: LogicalPlan = null
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    try {
+      spark.listenerManager.register(listener)
+
+      val t1 = "testcat.ns1.ns2.tbl"
+
+      val df1 = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      df1.write.option("option1", "20").saveAsTable(t1)
+
+      sparkContext.listenerBus.waitUntilEmpty()
+      plan match {
+        case o: CreateTableAsSelect =>
+          assert(o.writeOptions == Map("option1" -> "20"))
+        case other =>
+          fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query, " +
+            s"got ${other.getClass.getName}: $plan")
+      }
+      checkAnswer(spark.table(t1), df1)
+
+      val df2 = Seq((1L, "d"), (2L, "e"), (3L, "f")).toDF("id", "data")
+      df2.write.option("option2", "30").mode("overwrite").saveAsTable(t1)
+
+      sparkContext.listenerBus.waitUntilEmpty()
+      plan match {
+        case o: ReplaceTableAsSelect =>
+          assert(o.writeOptions == Map("option2" -> "30"))
+        case other =>
+          fail(s"Expected to parse ${classOf[ReplaceTableAsSelect].getName} from query, " +
+            s"got ${other.getClass.getName}: $plan")
+      }
+
+      checkAnswer(spark.table(t1), df2)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }
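
Note: the snippet below is a minimal usage sketch of the behavior this patch enables, not part of the patch itself. It assumes a DSv2 catalog registered under the name "testcat", backed by the test-only InMemoryTableCatalog (mirroring the suite's setup); the table name and option key are illustrative.

// Register an in-memory DSv2 catalog under the assumed name "testcat",
// the same way the test suite wires one up.
spark.conf.set(
  "spark.sql.catalog.testcat",
  classOf[org.apache.spark.sql.connector.catalog.InMemoryTableCatalog].getName)

val df = Seq((1L, "a"), (2L, "b")).toDF("id", "data")

// saveAsTable on a missing v2 table plans a CreateTableAsSelect, and an
// overwrite of an existing one plans a ReplaceTableAsSelect. Before this
// patch both plans were built with writeOptions = Map.empty, so "option1"
// was silently dropped; with extraOptions.toMap it now reaches the plan's
// writeOptions and is handed to the catalog's writer.
df.write.option("option1", "20").saveAsTable("testcat.ns1.ns2.tbl")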