diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 497ef848ac78f..ab17b93ad6146 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -295,7 +295,7 @@ abstract class InMemoryBaseTable(
     TableCapability.TRUNCATE)
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-    new InMemoryScanBuilder(schema)
+    new InMemoryScanBuilder(schema, options)
   }
 
   private def canEvaluate(filter: Filter): Boolean = {
@@ -309,8 +309,10 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
-    with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
+  class InMemoryScanBuilder(
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap) extends ScanBuilder
+    with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
     private var schema: StructType = tableSchema
     private var postScanFilters: Array[Filter] = Array.empty
     private var evaluableFilters: Array[Filter] = Array.empty
@@ -318,7 +320,7 @@ abstract class InMemoryBaseTable(
 
     override def build: Scan = {
       val scan = InMemoryBatchScan(
-        data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
+        data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
       if (evaluableFilters.nonEmpty) {
         scan.filter(evaluableFilters)
       }
@@ -442,7 +444,8 @@ abstract class InMemoryBaseTable(
   case class InMemoryBatchScan(
       var _data: Seq[InputPartition],
       readSchema: StructType,
-      tableSchema: StructType)
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap)
     extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
 
     override def filterAttributes(): Array[NamedReference] = {
@@ -474,17 +477,17 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite
-    with SupportsStreamingUpdateAsAppend {
+  abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo)
+    extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {
 
-    protected var writer: BatchWrite = Append
-    protected var streamingWriter: StreamingWrite = StreamingAppend
+    protected var writer: BatchWrite = new Append(info)
+    protected var streamingWriter: StreamingWrite = new StreamingAppend(info)
 
     override def overwriteDynamicPartitions(): WriteBuilder = {
-      if (writer != Append) {
+      if (!writer.isInstanceOf[Append]) {
         throw new IllegalArgumentException(s"Unsupported writer type: $writer")
       }
-      writer = DynamicOverwrite
+      writer = new DynamicOverwrite(info)
       streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
       this
     }
@@ -529,13 +532,13 @@ abstract class InMemoryBaseTable(
     override def abort(messages: Array[WriterCommitMessage]): Unit = {}
   }
 
-  protected object Append extends TestBatchWrite {
+  class Append(val info: LogicalWriteInfo) extends TestBatchWrite {
    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
      withData(messages.map(_.asInstanceOf[BufferedRows]))
    }
  }
 
-  private object DynamicOverwrite extends TestBatchWrite {
+  class DynamicOverwrite(val info: LogicalWriteInfo) extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       val newData = messages.map(_.asInstanceOf[BufferedRows])
       dataMap --= newData.flatMap(_.rows.map(getKey))
@@ -543,7 +546,7 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  protected object TruncateAndAppend extends TestBatchWrite {
+  class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       dataMap.clear()
       withData(messages.map(_.asInstanceOf[BufferedRows]))
@@ -572,7 +575,7 @@ abstract class InMemoryBaseTable(
       s"${operation} isn't supported for streaming query.")
   }
 
-  private object StreamingAppend extends TestStreamingWrite {
+  class StreamingAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
     override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
       dataMap.synchronized {
         withData(messages.map(_.asInstanceOf[BufferedRows]))
@@ -580,7 +583,7 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  protected object StreamingTruncateAndAppend extends TestStreamingWrite {
+  class StreamingTruncateAndAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
     override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
       dataMap.synchronized {
         dataMap.clear()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index 4abe4c8b3e3fb..3a684dc57c02f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -59,7 +59,7 @@ class InMemoryRowLevelOperationTable(
   }
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-    new InMemoryScanBuilder(schema) {
+    new InMemoryScanBuilder(schema, options) {
       override def build: Scan = {
         val scan = super.build()
         configuredScan = scan.asInstanceOf[InMemoryBatchScan]
@@ -115,7 +115,7 @@ class InMemoryRowLevelOperationTable(
     override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)
 
     override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-      new InMemoryScanBuilder(schema)
+      new InMemoryScanBuilder(schema, options)
     }
 
     override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index af04816e6b6f0..c27b8fea059f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -84,23 +84,23 @@ class InMemoryTable(
     InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
     InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
 
-    new InMemoryWriterBuilderWithOverWrite()
+    new InMemoryWriterBuilderWithOverWrite(info)
   }
 
-  private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
-    with SupportsOverwrite {
+  class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
+    extends InMemoryWriterBuilder(info) with SupportsOverwrite {
 
     override def truncate(): WriteBuilder = {
-      if (writer != Append) {
+      if (!writer.isInstanceOf[Append]) {
         throw new IllegalArgumentException(s"Unsupported writer type: $writer")
       }
-      writer = TruncateAndAppend
-      streamingWriter = StreamingTruncateAndAppend
+      writer = new TruncateAndAppend(info)
+      streamingWriter = new StreamingTruncateAndAppend(info)
       this
     }
 
     override def overwrite(filters: Array[Filter]): WriteBuilder = {
-      if (writer != Append) {
+      if (!writer.isInstanceOf[Append]) {
         throw new IllegalArgumentException(s"Unsupported writer type: $writer")
       }
       writer = new Overwrite(filters)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
index 20ada0d622bca..9b7a90774f91c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -47,19 +47,22 @@ class InMemoryTableWithV2Filter(
   }
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-    new InMemoryV2FilterScanBuilder(schema)
+    new InMemoryV2FilterScanBuilder(schema, options)
   }
 
-  class InMemoryV2FilterScanBuilder(tableSchema: StructType)
-    extends InMemoryScanBuilder(tableSchema) {
+  class InMemoryV2FilterScanBuilder(
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap)
+    extends InMemoryScanBuilder(tableSchema, options) {
     override def build: Scan = InMemoryV2FilterBatchScan(
-      data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
+      data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
   }
 
   case class InMemoryV2FilterBatchScan(
       var _data: Seq[InputPartition],
       readSchema: StructType,
-      tableSchema: StructType)
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap)
     extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {
 
     override def filterAttributes(): Array[NamedReference] = {
@@ -93,21 +96,21 @@ class InMemoryTableWithV2Filter(
     InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
     InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
 
-    new InMemoryWriterBuilderWithOverWrite()
+    new InMemoryWriterBuilderWithOverWrite(info)
   }
 
-  private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
-    with SupportsOverwriteV2 {
+  class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
+    extends InMemoryWriterBuilder(info) with SupportsOverwriteV2 {
 
     override def truncate(): WriteBuilder = {
-      assert(writer == Append)
-      writer = TruncateAndAppend
-      streamingWriter = StreamingTruncateAndAppend
+      assert(writer.isInstanceOf[Append])
+      writer = new TruncateAndAppend(info)
+      streamingWriter = new StreamingTruncateAndAppend(info)
       this
     }
 
     override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
-      assert(writer == Append)
+      assert(writer.isInstanceOf[Append])
       writer = new Overwrite(predicates)
       streamingWriter = new StreamingNotSupportedOperation(
         s"overwrite (${predicates.mkString("filters(", ", ", ")")})")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 319cc1c731577..17b2579ca873a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import java.util.{Optional, UUID}
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -44,7 +46,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) =>
-      val writeBuilder = newWriteBuilder(r.table, options, query.schema)
+      val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
       val write = writeBuilder.build()
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       a.copy(write = Some(write), query = newQuery)
@@ -61,7 +64,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
       }.toArray
 
       val table = r.table
-      val writeBuilder = newWriteBuilder(table, options, query.schema)
+      val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
       val write = writeBuilder match {
         case builder: SupportsTruncate if isTruncate(predicates) =>
           builder.truncate().build()
@@ -76,7 +80,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
 
     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
       val table = r.table
-      val writeBuilder = newWriteBuilder(table, options, query.schema)
+      val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
       val write = writeBuilder match {
         case builder: SupportsDynamicOverwrite =>
           builder.overwriteDynamicPartitions().build()
@@ -87,31 +92,44 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
       o.copy(write = Some(write), query = newQuery)
 
     case WriteToMicroBatchDataSource(
-        relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) =>
-
+        relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
+      val writeOptions = mergeOptions(
+        options, relationOpt.map(r => r.options.asScala.toMap).getOrElse(Map.empty))
       val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId)
       val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
       val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
       val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq
-      val funCatalogOpt = relation.flatMap(_.funCatalog)
+      val funCatalogOpt = relationOpt.flatMap(_.funCatalog)
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
-      WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
+      WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics)
 
     case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) =>
       val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput)
-      val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
+      val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema)
       val write = writeBuilder.build()
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       // project away any metadata columns that could be used for distribution and ordering
       rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
 
     case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
-      val deltaWriteBuilder = newDeltaWriteBuilder(r.table, Map.empty, projections)
+      val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
+      val deltaWriteBuilder = newDeltaWriteBuilder(r.table, writeOptions, projections)
       val deltaWrite = deltaWriteBuilder.build()
       val newQuery = DistributionAndOrderingUtils.prepareQuery(deltaWrite, query, r.funCatalog)
       wd.copy(write = Some(deltaWrite), query = newQuery)
   }
 
+  private def mergeOptions(
+      commandOptions: Map[String, String],
+      dsOptions: Map[String, String]): Map[String, String] = {
+    // for DataFrame API cases, same options are carried by both Command and DataSourceV2Relation
+    // for DataFrameV2 API cases, options are only carried by Command
+    // for SQL cases, options are only carried by DataSourceV2Relation
+    assert(commandOptions == dsOptions || commandOptions.isEmpty || dsOptions.isEmpty)
+    commandOptions ++ dsOptions
+  }
+
   private def buildWriteForMicroBatch(
       table: SupportsWrite,
       writeBuilder: WriteBuilder,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 30180d48da71a..b59c83c23d3c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -27,7 +27,7 @@ import org.scalatest.Assertions
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.storage.StorageLevel
@@ -449,12 +449,12 @@ object QueryTest extends Assertions {
     }
   }
 
-  def withPhysicalPlansCaptured(spark: SparkSession, thunk: => Unit): Seq[SparkPlan] = {
-    var capturedPlans = Seq.empty[SparkPlan]
+  def withQueryExecutionsCaptured(spark: SparkSession)(thunk: => Unit): Seq[QueryExecution] = {
+    var capturedQueryExecutions = Seq.empty[QueryExecution]
 
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        capturedPlans = capturedPlans :+ qe.executedPlan
+        capturedQueryExecutions = capturedQueryExecutions :+ qe
       }
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     }
@@ -468,7 +468,7 @@
 
       spark.listenerManager.unregister(listener)
     }
 
-    capturedPlans
+    capturedQueryExecutions
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
new file mode 100644
index 0000000000000..70291336ba317
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog.InMemoryBaseTable
+import org.apache.spark.sql.execution.CommandResultExec
+import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.functions.lit
+
+class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
+  import testImplicits._
+
+  private val catalogAndNamespace = "testcat.ns1.ns2."
+
+  test("SPARK-36680: Supports Dynamic Table Options for SQL Select") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      var df = sql(s"SELECT * FROM $t1")
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.isEmpty)
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      df = sql(s"SELECT * FROM $t1 WITH (`split-size` = 5)")
+      collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.get("split-size") == "5")
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      collected = df.queryExecution.executedPlan.collect {
+        case BatchScanExec(_, scan: InMemoryBaseTable#InMemoryBatchScan, _, _, _, _) =>
+          assert(scan.options.get("split-size") === "5")
+      }
+      assert (collected.size == 1)
+
+      val noValues = intercept[AnalysisException](
+        sql(s"SELECT * FROM $t1 WITH (`split-size`)"))
+      assert(noValues.message.contains(
+        "Operation not allowed: Values must be specified for key(s): [split-size]"))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameReader") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      var df = spark.table(t1)
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.isEmpty)
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      df = spark.read.option("split-size", "5").table(t1)
+      collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.get("split-size") == "5")
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      collected = df.queryExecution.executedPlan.collect {
+        case BatchScanExec(_, scan: InMemoryBaseTable#InMemoryBatchScan, _, _, _, _) =>
+          assert(scan.options.get("split-size") === "5")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-49098, SPARK-50286: Supports Dynamic Table Options for SQL Insert") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
+
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = df.queryExecution.executedPlan.collect {
+        case CommandResultExec(
+          _, AppendDataExec(_, _, write),
+          _) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(1, "a"), Row(2, "b")))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriter Append") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(1 -> "a", 2 -> "b").toDF("id", "data")
+          .write
+          .option("write.split-size", "10")
+          .mode("append")
+          .insertInto(t1)
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case AppendData(_: DataSourceV2Relation, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case AppendDataExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriterV2 Append") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(1 -> "a", 2 -> "b").toDF("id", "data")
+          .writeTo(t1)
+          .option("write.split-size", "10")
+          .append()
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case AppendData(_: DataSourceV2Relation, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case AppendDataExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-49098, SPARK-50286: Supports Dynamic Table Options for SQL Insert Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val df = sql(s"INSERT OVERWRITE $t1 WITH (`write.split-size` = 10) " +
+        s"VALUES (3, 'c'), (4, 'd')")
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_,
+          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
+          _, _) =>
+          assert(relation.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = df.queryExecution.executedPlan.collect {
+        case CommandResultExec(
+          _, OverwriteByExpressionExec(_, _, write),
+          _) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriterV2 OverwritePartitions") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(3 -> "c", 4 -> "d").toDF("id", "data")
+          .writeTo(t1)
+          .option("write.split-size", "10")
+          .overwritePartitions()
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case OverwritePartitionsDynamic(_: DataSourceV2Relation, _, writeOptions, _, _) =>
+          assert(writeOptions("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case OverwritePartitionsDynamicExec(_, _, write) =>
+          val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite]
+          assert(dynOverwrite.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-49098, SPARK-50286: Supports Dynamic Table Options for SQL Insert Replace") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) " +
+        s"REPLACE WHERE TRUE " +
+        s"VALUES (3, 'c'), (4, 'd')")
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_,
+          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
+          _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = df.queryExecution.executedPlan.collect {
+        case CommandResultExec(
+          _, OverwriteByExpressionExec(_, _, write),
+          _) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriter Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(1 -> "a", 2 -> "b").toDF("id", "data")
+          .write
+          .option("write.split-size", "10")
+          .mode("overwrite")
+          .insertInto(t1)
+      }
+      assert(captured.size === 1)
+
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case OverwriteByExpression(_: DataSourceV2Relation, _, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case OverwriteByExpressionExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriterV2 Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(3 -> "c", 4 -> "d").toDF("id", "data")
+          .writeTo(t1)
+          .option("write.split-size", "10")
+          .overwrite(lit(true))
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+
+      var collected = qe.optimizedPlan.collect {
+        case OverwriteByExpression(_: DataSourceV2Relation, _, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case OverwriteByExpressionExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 510ea49b58418..6a659fa6e3ee9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils}
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, ColumnStat, CommandResult, OverwriteByExpression}
+import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
 import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, _}
@@ -44,7 +44,6 @@ import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelationWithTable}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -3634,96 +3633,6 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }
 
-
-  test("SPARK-36680: Supports Dynamic Table Options for Spark SQL") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
-
-      var df = sql(s"SELECT * FROM $t1")
-      var collected = df.queryExecution.optimizedPlan.collect {
-        case scan: DataSourceV2ScanRelation =>
-          assert(scan.relation.options.isEmpty)
-      }
-      assert (collected.size == 1)
-      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
-
-      df = sql(s"SELECT * FROM $t1 WITH (`split-size` = 5)")
-      collected = df.queryExecution.optimizedPlan.collect {
-        case scan: DataSourceV2ScanRelation =>
-          assert(scan.relation.options.get("split-size") == "5")
-      }
-      assert (collected.size == 1)
-      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
-
-      val noValues = intercept[AnalysisException](
-        sql(s"SELECT * FROM $t1 WITH (`split-size`)"))
-      assert(noValues.message.contains(
-        "Operation not allowed: Values must be specified for key(s): [split-size]"))
-    }
-  }
-
-  test("SPARK-36680: Supports Dynamic Table Options for Insert") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
-
-      val collected = df.queryExecution.optimizedPlan.collect {
-        case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
-          assert(relation.options.get("write.split-size") == "10")
-      }
-      assert (collected.size == 1)
-
-      val insertResult = sql(s"SELECT * FROM $t1")
-      checkAnswer(insertResult, Seq(Row(1, "a"), Row(2, "b")))
-    }
-  }
-
-  test("SPARK-36680: Supports Dynamic Table Options for Insert Overwrite") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
-
-      val df = sql(s"INSERT OVERWRITE $t1 WITH (`write.split-size` = 10) " +
-        s"VALUES (3, 'c'), (4, 'd')")
-      val collected = df.queryExecution.optimizedPlan.collect {
-        case CommandResult(_,
-          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
-          _, _) =>
-          assert(relation.options.get("write.split-size") == "10")
-      }
-      assert (collected.size == 1)
-
-      val insertResult = sql(s"SELECT * FROM $t1")
-      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
-    }
-  }
-
-  test("SPARK-36680: Supports Dynamic Table Options for Insert Replace") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
-
-      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) " +
-        s"REPLACE WHERE TRUE " +
-        s"VALUES (3, 'c'), (4, 'd')")
-      val collected = df.queryExecution.optimizedPlan.collect {
-        case CommandResult(_,
-          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
-          _, _) =>
-          assert(relation.options.get("write.split-size") == "10")
-      }
-      assert (collected.size == 1)
-
-      val insertResult = sql(s"SELECT * FROM $t1")
-      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
-    }
-  }
-
   test("SPARK-49183: custom spark_catalog generates location for managed tables") {
     // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can
     // configure a new implementation.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 04fc7e23ebb24..68c2a01c69aea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
-import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured
+import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -213,8 +213,8 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
       .getOrCreate()
 
   def captureWrite(sparkSession: SparkSession)(thunk: => Unit): SparkPlan = {
-    val physicalPlans = withPhysicalPlansCaptured(sparkSession, thunk)
-    val v1FallbackWritePlans = physicalPlans.filter {
+    val queryExecutions = withQueryExecutionsCaptured(sparkSession)(thunk)
+    val v1FallbackWritePlans = queryExecutions.map(_.executedPlan).filter {
       case _: AppendDataExecV1 | _: OverwriteByExpressionExecV1 => true
       case _ => false
     }