diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 04ad8fd90be9..0e2eb9c3cabb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -26,6 +26,10 @@ * The base interface for v2 data sources which don't have a real catalog. Implementations must * have a public, 0-arg constructor. *

+ * Note that TableProvider can only apply data operations to existing tables, like read, append, + * delete, and overwrite. It does not support operations that require metadata changes, like + * create/drop tables. + *

* The major responsibility of this interface is to return a {@link Table} for read/write. *

*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java deleted file mode 100644 index c4295f237187..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer; - -import org.apache.spark.sql.SaveMode; - -// A temporary mixin trait for `WriteBuilder` to support `SaveMode`. Will be removed before -// Spark 3.0 when all the new write operators are finished. See SPARK-26356 for more details. -public interface SupportsSaveMode extends WriteBuilder { - WriteBuilder mode(SaveMode mode); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index aab46b078c33..eeb6a9bb84f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -60,10 +60,6 @@ default WriteBuilder withInputDataSchema(StructType schema) { * exception, data sources must overwrite this method to provide an implementation, if the * {@link Table} that creates this write returns {@link TableCapability#BATCH_WRITE} support in * its {@link Table#capabilities()}. - * - * Note that, the returned {@link BatchWrite} can be null if the implementation supports SaveMode, - * to indicate that no writing is needed. We can clean it up after removing - * {@link SupportsSaveMode}. 
*/ default BatchWrite buildForBatch() { throw new UnsupportedOperationException(getClass().getName() + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 18653b2bf542..0c48ec9bb465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,12 +30,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.TableCapability._ -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -56,13 +55,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
* <li>`SaveMode.Overwrite`: overwrite the existing data.</li>
* <li>`SaveMode.Append`: append the data.</li>
* <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
- * <li>`SaveMode.ErrorIfExists`: default option, throw an exception at runtime.</li>
+ * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li>
* </ul>
+ * <p>
    + * When writing to data source v1, the default option is `ErrorIfExists`. When writing to data + * source v2, the default option is `Append`. * * @since 1.4.0 */ def mode(saveMode: SaveMode): DataFrameWriter[T] = { - this.mode = saveMode + this.mode = Some(saveMode) this } @@ -78,15 +80,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase(Locale.ROOT) match { - case "overwrite" => SaveMode.Overwrite - case "append" => SaveMode.Append - case "ignore" => SaveMode.Ignore - case "error" | "errorifexists" | "default" => SaveMode.ErrorIfExists + saveMode.toLowerCase(Locale.ROOT) match { + case "overwrite" => mode(SaveMode.Overwrite) + case "append" => mode(SaveMode.Append) + case "ignore" => mode(SaveMode.Ignore) + case "error" | "errorifexists" => mode(SaveMode.ErrorIfExists) + case "default" => this case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + "Accepted save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists'.") } - this } /** @@ -268,9 +270,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { + // TODO (SPARK-27815): To not break existing tests, here we treat file source as a special + // case, and pass the save mode to file source directly. This hack should be removed. + case table: FileTable => + val write = table.newWriteBuilder(dsOptions).asInstanceOf[FileWriteBuilder] + .mode(modeForDSV1) // should not change default mode for file source. + .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(df.logicalPlan.schema) + .buildForBatch() + // The returned `Write` can be null, which indicates that we can skip writing. + if (write != null) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(write, df.logicalPlan) + } + } + case table: SupportsWrite if table.supports(BATCH_WRITE) => lazy val relation = DataSourceV2Relation.create(table, dsOptions) - mode match { + modeForDSV2 match { case SaveMode.Append => runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan) @@ -282,25 +299,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) } - case _ => - table.newWriteBuilder(dsOptions) match { - case writeBuilder: SupportsSaveMode => - val write = writeBuilder.mode(mode) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(df.logicalPlan.schema) - .buildForBatch() - // It can only return null with `SupportsSaveMode`. We can clean it up after - // removing `SupportsSaveMode`. - if (write != null) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(write, df.logicalPlan) - } - } - - case _ => - throw new AnalysisException( - s"data source ${table.name} does not support SaveMode $mode") - } + case other => + throw new AnalysisException(s"TableProvider implementation $source cannot be " + + s"written with $other mode, please use Append or Overwrite " + + "modes instead.") } // Streaming also uses the data source V2 API. 
So it may be that the data source implements @@ -328,7 +330,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + options = extraOptions.toMap).planForWriting(modeForDSV1, df.logicalPlan) } } @@ -377,7 +379,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], query = df.logicalPlan, - overwrite = mode == SaveMode.Overwrite, + overwrite = modeForDSV1 == SaveMode.Overwrite, ifPartitionNotExists = false) } } @@ -457,7 +459,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val tableIdentWithDB = tableIdent.copy(database = Some(db)) val tableName = tableIdentWithDB.unquotedString - (tableExists, mode) match { + (tableExists, modeForDSV1) match { case (true, SaveMode.Ignore) => // Do nothing @@ -512,7 +514,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) + runCommand(df.sparkSession, "saveAsTable")( + CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan))) } /** @@ -718,13 +721,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd) } + private def modeForDSV1 = mode.getOrElse(SaveMode.ErrorIfExists) + + private def modeForDSV2 = mode.getOrElse(SaveMode.Append) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName - private var mode: SaveMode = SaveMode.ErrorIfExists + private var mode: Option[SaveMode] = None private val extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 6b4efaf303c6..e4f9e49c4dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -21,7 +21,6 @@ import java.util import scala.collection.JavaConverters._ -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ @@ -47,13 +46,12 @@ private[noop] object NoopTable extends Table with SupportsWrite { Set( TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE, + TableCapability.TRUNCATE, TableCapability.ACCEPT_ANY_SCHEMA).asJava } } -private[noop] object NoopWriteBuilder extends WriteBuilder - with SupportsSaveMode with SupportsTruncate { - override def mode(mode: SaveMode): WriteBuilder = this +private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { override def truncate(): WriteBuilder = this override def buildForBatch(): BatchWrite = NoopBatchWrite override def buildForStreaming(): StreamingWrite = NoopStreamingWrite diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index 7ff5c4182d98..eacc4cb3ac4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.util.SchemaUtils @@ -43,8 +43,7 @@ abstract class FileWriteBuilder( options: CaseInsensitiveStringMap, paths: Seq[String], formatName: String, - supportsDataType: DataType => Boolean) - extends WriteBuilder with SupportsSaveMode { + supportsDataType: DataType => Boolean) extends WriteBuilder { private var schema: StructType = _ private var queryId: String = _ private var mode: SaveMode = _ @@ -59,7 +58,7 @@ abstract class FileWriteBuilder( this } - override def mode(mode: SaveMode): WriteBuilder = { + def mode(mode: SaveMode): WriteBuilder = { this.mode = mode this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1797166bbe0b..6c771ea98832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.InternalRow @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} import org.apache.spark.sql.sources.v2.SupportsWrite -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -81,16 +80,10 @@ case class CreateTableAsSelectExec( Utils.tryWithSafeFinallyAndFailureCallbacks({ catalog.createTable(ident, query.schema, partitioning.toArray, properties.asJava) match { case table: SupportsWrite => - val builder = table.newWriteBuilder(writeOptions) - 
.withInputDataSchema(query.schema) - .withQueryId(UUID.randomUUID().toString) - val batchWrite = builder match { - case supportsSaveMode: SupportsSaveMode => - supportsSaveMode.mode(SaveMode.Append).buildForBatch() - - case _ => - builder.buildForBatch() - } + val batchWrite = table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + .buildForBatch() doWrite(batchWrite) @@ -116,13 +109,7 @@ case class AppendDataExec( query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def doExecute(): RDD[InternalRow] = { - val batchWrite = newWriteBuilder() match { - case builder: SupportsSaveMode => - builder.mode(SaveMode.Append).buildForBatch() - - case builder => - builder.buildForBatch() - } + val batchWrite = newWriteBuilder().buildForBatch() doWrite(batchWrite) } } @@ -152,9 +139,6 @@ case class OverwriteByExpressionExec( case builder: SupportsTruncate if isTruncate(deleteWhere) => builder.truncate().buildForBatch() - case builder: SupportsSaveMode if isTruncate(deleteWhere) => - builder.mode(SaveMode.Overwrite).buildForBatch() - case builder: SupportsOverwrite => builder.overwrite(deleteWhere).buildForBatch() @@ -185,9 +169,6 @@ case class OverwritePartitionsDynamicExec( case builder: SupportsDynamicOverwrite => builder.overwriteDynamicPartitions().buildForBatch() - case builder: SupportsSaveMode => - builder.mode(SaveMode.Overwrite).buildForBatch() - case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } @@ -350,8 +331,8 @@ object DataWritingSparkTask extends Logging { } private[v2] case class DataWritingSparkTaskResult( - numRows: Long, - writerCommitMessage: WriterCommitMessage) + numRows: Long, + writerCommitMessage: WriterCommitMessage) /** * Sink progress information collected after commit. 
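
Editor's note, for orientation between the two files above: with SupportsSaveMode removed, a v2 write builder that wants "replace everything" semantics now opts into SupportsTruncate, which OverwriteByExpressionExec invokes when the delete filter is AlwaysTrue. The sketch below is illustrative only and not part of the patch; the class name and the truncateFirst flag are made up, and the BatchWrite is left as a stub. It is modeled on the NoopWriteBuilder and MyWriteBuilder changes elsewhere in this diff.

import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsTruncate, WriteBuilder}

// Illustrative builder: truncate() replaces the old SupportsSaveMode.mode(Overwrite) path.
class SketchWriteBuilder extends WriteBuilder with SupportsTruncate {
  private var truncateFirst = false

  // Called by OverwriteByExpressionExec for a full overwrite (AlwaysTrue delete filter).
  override def truncate(): WriteBuilder = {
    truncateFirst = true
    this
  }

  override def buildForBatch(): BatchWrite = {
    // A real source would remove the existing data when truncateFirst is set
    // (MyWriteBuilder in SimpleWritableDataSource below deletes its path at this
    // point), and then return its actual BatchWrite implementation.
    ??? // stub: return the source's BatchWrite
  }
}

This mirrors the new dispatch in the exec nodes above: truncate() for AlwaysTrue deletes, overwrite(filters) for SupportsOverwrite, and overwriteDynamicPartitions() for dynamic partition overwrite; ErrorIfExists and Ignore no longer reach the builder at all.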
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 4e071c5af6a6..379c9c4303cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -219,14 +219,14 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).save() + .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) - // test with different save modes + // default save mode is append spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("append").save() + .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) @@ -237,17 +237,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("ignore").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) + val e = intercept[AnalysisException] { + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) + .option("path", path).mode("ignore").save() + } + assert(e.message.contains("please use Append or Overwrite modes instead")) - val e = intercept[Exception] { + val e2 = intercept[AnalysisException] { spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } - assert(e.getMessage.contains("data already exists")) + assert(e2.getMessage.contains("please use Append or Overwrite modes instead")) // test transaction val failingUdf = org.apache.spark.sql.functions.udf { @@ -262,10 +262,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } // this input data will fail to read middle way. val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) - val e2 = intercept[SparkException] { + val e3 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } - assert(e2.getMessage.contains("Writing job aborted")) + assert(e3.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. 
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) } @@ -375,24 +375,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { - withTempPath { file => - val cls = classOf[SimpleWriteOnlyDataSource] - val path = file.getCanonicalPath - val df = spark.range(5).select('id as 'i, -'id as 'j) - // non-append mode should not throw exception, as they don't access schema. - df.write.format(cls.getName).option("path", path).mode("error").save() - df.write.format(cls.getName).option("path", path).mode("ignore").save() - // append and overwrite modes will access the schema and should throw exception. - intercept[SchemaReadAttemptException] { - df.write.format(cls.getName).option("path", path).mode("append").save() - } - intercept[SchemaReadAttemptException] { - df.write.format(cls.getName).option("path", path).mode("overwrite").save() - } - } - } - test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") { withTempView("t1") { val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index 8627bdf4ae18..3ae305655e92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -50,7 +50,7 @@ class DummyReadOnlyFileTable extends Table with SupportsRead { } override def capabilities(): java.util.Set[TableCapability] = - Set(TableCapability.BATCH_READ).asJava + Set(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA).asJava } class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -73,7 +73,7 @@ class DummyWriteOnlyFileTable extends Table with SupportsWrite { throw new AnalysisException("Dummy file writer") override def capabilities(): java.util.Set[TableCapability] = - Set(TableCapability.BATCH_WRITE).asJava + Set(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava } class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index edebb0b62b29..c9d2f1eef24b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ @@ -70,38 +69,26 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { override def readSchema(): StructType = tableSchema } - class MyWriteBuilder(path: String) extends WriteBuilder with SupportsSaveMode { + class MyWriteBuilder(path: String) extends WriteBuilder with SupportsTruncate { private var queryId: String = _ - private var mode: SaveMode = _ + private var needTruncate = false override def withQueryId(queryId: String): WriteBuilder = 
{ this.queryId = queryId this } - override def mode(mode: SaveMode): WriteBuilder = { - this.mode = mode + override def truncate(): WriteBuilder = { + this.needTruncate = true this } override def buildForBatch(): BatchWrite = { - assert(mode != null) - val hadoopPath = new Path(path) val hadoopConf = SparkContext.getActive.get.hadoopConfiguration val fs = hadoopPath.getFileSystem(hadoopConf) - if (mode == SaveMode.ErrorIfExists) { - if (fs.exists(hadoopPath)) { - throw new RuntimeException("data already exists.") - } - } - if (mode == SaveMode.Ignore) { - if (fs.exists(hadoopPath)) { - return null - } - } - if (mode == SaveMode.Overwrite) { + if (needTruncate) { fs.delete(hadoopPath, true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index d34da330496b..5e6e3b4fc164 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -38,11 +38,15 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression} +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.util.Utils @@ -239,15 +243,75 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } test("save mode") { - val df = spark.read + spark.range(10).write .format("org.apache.spark.sql.test") - .load() + .mode(SaveMode.ErrorIfExists) + .save() + assert(LastOptions.saveMode === SaveMode.ErrorIfExists) - df.write + spark.range(10).write + .format("org.apache.spark.sql.test") + .mode(SaveMode.Append) + .save() + assert(LastOptions.saveMode === SaveMode.Append) + + // By default the save mode is `ErrorIfExists` for data source v1. 
+    spark.range(10).write
       .format("org.apache.spark.sql.test")
-      .mode(SaveMode.ErrorIfExists)
       .save()
     assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
+
+    spark.range(10).write
+      .format("org.apache.spark.sql.test")
+      .mode("default")
+      .save()
+    assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
+  }
+
+  test("save mode for data source v2") {
+    var plan: LogicalPlan = null
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+    try {
+      // append mode creates `AppendData`
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .mode(SaveMode.Append)
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[AppendData])
+
+      // overwrite mode creates `OverwriteByExpression`
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .mode(SaveMode.Overwrite)
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[OverwriteByExpression])
+
+      // By default the save mode is `Append` for data source v2.
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[AppendData])
+
+      spark.range(10).write
+        .format(classOf[NoopDataSource].getName)
+        .mode("default")
+        .save()
+      sparkContext.listenerBus.waitUntilEmpty(1000)
+      assert(plan.isInstanceOf[AppendData])
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
   }

   test("test path option in load") {
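
Editor's note: taken together with the DataFrameWriter changes above, the behavior change boils down to the defaults shown in the sketch below. This is illustrative usage, not part of the patch; it assumes a build containing this change, the `noop` short name registered by NoopDataSource, and a throwaway /tmp path.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("save-mode-sketch").getOrCreate()
val df = spark.range(10).toDF("id")

// Data source v1: omitting mode() still means SaveMode.ErrorIfExists,
// so a second run against the same existing path fails.
df.write.format("parquet").save("/tmp/save_mode_sketch_v1")

// Data source v2: omitting mode() now means SaveMode.Append, planned as AppendData.
df.write.format("noop").save()

// Overwrite is planned as OverwriteByExpression(AlwaysTrue); the noop table
// accepts it because this patch adds TableCapability.TRUNCATE to it.
df.write.format("noop").mode("overwrite").save()

// ErrorIfExists and Ignore are rejected for v2 sources, since a bare
// TableProvider cannot check for or create tables.
// df.write.format("noop").mode("ignore").save()   // throws AnalysisException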