34 changes: 12 additions & 22 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -67,7 +67,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @since 1.4.0
*/
def mode(saveMode: SaveMode): DataFrameWriter[T] = {
this.mode = Some(saveMode)
this.mode = saveMode
this
}

@@ -267,7 +267,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
"if partition columns are specified.")
}
lazy val relation = DataSourceV2Relation.create(table, dsOptions)
modeForDSV2 match {
mode match {
case SaveMode.Append =>
runCommand(df.sparkSession, "save") {
AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
@@ -308,7 +308,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
sparkSession = df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
options = extraOptions.toMap).planForWriting(modeForDSV1, df.logicalPlan)
options = extraOptions.toMap).planForWriting(mode, df.logicalPlan)
}
}

@@ -380,8 +380,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
DataSourceV2Relation.create(t)
}

val command = modeForDSV2 match {
case SaveMode.Append =>
val command = mode match {
case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore =>
@cloud-fan (Contributor) commented on Sep 24, 2019:
A note to future readers: this is the old behavior, that non-overwrite mode means append. This is due to the bad design of DataFrameWriter: we only need to know overwrite or not when calling insert, but DataFrameWriter gives you a save mode. Since the default save mode is ErrorIfExists, treating non-overwrite mode as append is a reasonable compromise.

Note that, we don't have this problem in the new DataFrameWriterV2.

Another contributor replied:

Looks like the previous version used this:

      InsertIntoTable(
        table = UnresolvedRelation(tableIdent),
        partition = Map.empty[String, Option[String]],
        query = df.logicalPlan,
        overwrite = mode == SaveMode.Overwrite, // << Either overwrite or append
        ifPartitionNotExists = false)

So I agree that this is using the same behavior that v1 did.
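For readers skimming the diff: after this patch the v2 insertInto path only needs to know whether a write is an overwrite, and every other save mode falls back to an append by position. A minimal sketch of that rule (overwriteForInsertInto is a made-up helper name, not the actual Spark code):

    import org.apache.spark.sql.SaveMode

    // Hypothetical helper, for illustration only: insertInto just needs an
    // overwrite flag; all non-overwrite modes keep the old v1 append behavior.
    def overwriteForInsertInto(mode: SaveMode): Boolean = mode match {
      case SaveMode.Overwrite => true
      case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore => false
    }

With the v1 plan quoted above this is just mode == SaveMode.Overwrite, so the v1 and v2 insertInto paths now agree on how modes are interpreted.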

AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap)

case SaveMode.Overwrite =>
@@ -394,10 +394,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
} else {
OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap)
}

case other =>
throw new AnalysisException(s"insertInto does not support $other mode, " +
s"please use Append or Overwrite mode instead.")
}

runCommand(df.sparkSession, "insertInto") {
@@ -411,7 +407,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
table = UnresolvedRelation(tableIdent),
partitionSpec = Map.empty[String, Option[String]],
query = df.logicalPlan,
overwrite = modeForDSV1 == SaveMode.Overwrite,
overwrite = mode == SaveMode.Overwrite,
ifPartitionNotExists = false)
}
}
@@ -490,12 +486,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
saveAsTable(catalog.asTableCatalog, ident, modeForDSV2)
saveAsTable(catalog.asTableCatalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
// We pass in the modeForDSV1, as using the V2 session catalog should maintain compatibility
// for now.
saveAsTable(sessionCatalog.asTableCatalog, ident, modeForDSV1)
saveAsTable(sessionCatalog.asTableCatalog, ident)

case AsTableIdentifier(tableIdentifier) =>
saveAsTable(tableIdentifier)
@@ -507,7 +501,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}


private def saveAsTable(catalog: TableCatalog, ident: Identifier, mode: SaveMode): Unit = {
private def saveAsTable(catalog: TableCatalog, ident: Identifier): Unit = {
val partitioning = partitioningColumns.map { colNames =>
colNames.map(name => IdentityTransform(FieldReference(name)))
}.getOrElse(Seq.empty[Transform])
@@ -568,7 +562,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val tableIdentWithDB = tableIdent.copy(database = Some(db))
val tableName = tableIdentWithDB.unquotedString

(tableExists, modeForDSV1) match {
(tableExists, mode) match {
case (true, SaveMode.Ignore) =>
// Do nothing

@@ -624,7 +618,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
bucketSpec = getBucketSpec)

runCommand(df.sparkSession, "saveAsTable")(
CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan)))
CreateTable(tableDesc, mode, Some(df.logicalPlan)))
}

/**
@@ -830,10 +824,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
}

private def modeForDSV1 = mode.getOrElse(SaveMode.ErrorIfExists)

private def modeForDSV2 = mode.getOrElse(SaveMode.Append)

private def lookupV2Provider(): Option[TableProvider] = {
DataSource.lookupDataSourceV2(source, df.sparkSession.sessionState.conf) match {
// TODO(SPARK-28396): File source v2 write path is currently broken.
@@ -848,7 +838,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName

private var mode: Option[SaveMode] = None
private var mode: SaveMode = SaveMode.ErrorIfExists

private val extraOptions = new scala.collection.mutable.HashMap[String, String]

DataSourceV2DataFrameSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.connector

import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.{DataFrame, Row, SaveMode}

class DataSourceV2DataFrameSuite
@@ -75,13 +76,15 @@ class DataSourceV2DataFrameSuite
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
// Default saveMode is append, therefore this doesn't throw a table already exists exception
df.write.saveAsTable(t1)
// Default saveMode is ErrorIfExists
intercept[TableAlreadyExistsException] {
df.write.saveAsTable(t1)
}
assert(spark.table(t1).count() === 0)

// appends are by name not by position
df.select('data, 'id).write.mode("append").saveAsTable(t1)
checkAnswer(spark.table(t1), df)

// also appends are by name not by position
df.select('data, 'id).write.saveAsTable(t1)
checkAnswer(spark.table(t1), df.union(df))
}
}

DataSourceV2Suite.scala
@@ -225,8 +225,12 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession {
spark.read.format(cls.getName).option("path", path).load(),
spark.range(10).select('id, -'id))

// default save mode is append
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName)
// default save mode is ErrorIfExists
intercept[AnalysisException] {
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName)
.option("path", path).save()
}
spark.range(10).select('id as 'i, -'id as 'j).write.mode("append").format(cls.getName)
.option("path", path).save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
@@ -281,7 +285,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession {

val numPartition = 6
spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName)
.option("path", path).save()
.mode("append").option("path", path).save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
spark.range(10).select('id, -'id))
@@ -368,7 +372,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession {
val format = classOf[SimpleWritableDataSource].getName

val df = Seq((1L, 2L)).toDF("i", "j")
df.write.format(format).option("path", optionPath).save()
df.write.format(format).mode("append").option("path", optionPath).save()
assert(!new File(sessionPath).exists)
checkAnswer(spark.read.format(format).option("path", optionPath).load(), df)
}
NoopSuite.scala
@@ -32,6 +32,7 @@ class NoopSuite extends SharedSparkSession {
}
.write
.format("noop")
.mode("append")
.save()
assert(accum.value == numElems)
}
@@ -54,7 +55,7 @@
accum.add(1)
x
}
.write.format("noop").save()
.write.mode("append").format("noop").save()
assert(accum.value == numElems)
}
}
DataFrameReaderWriterSuite.scala
@@ -289,18 +289,20 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with
assert(plan.isInstanceOf[OverwriteByExpression])

// By default the save mode is `ErrorIfExists` for data source v2.
spark.range(10).write
.format(classOf[NoopDataSource].getName)
.save()
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[AppendData])
val e = intercept[AnalysisException] {
spark.range(10).write
.format(classOf[NoopDataSource].getName)
.save()
}
assert(e.getMessage.contains("ErrorIfExists"))

spark.range(10).write
.format(classOf[NoopDataSource].getName)
.mode("default")
.save()
sparkContext.listenerBus.waitUntilEmpty()
assert(plan.isInstanceOf[AppendData])
val e2 = intercept[AnalysisException] {
spark.range(10).write
.format(classOf[NoopDataSource].getName)
.mode("default")
.save()
}
assert(e2.getMessage.contains("ErrorIfExists"))
} finally {
spark.listenerManager.unregister(listener)
}
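A note on the mode("default") call exercised above: the string overload of mode resolves names to SaveMode values, and "default" maps to ErrorIfExists, which is why both writes are expected to fail. A rough sketch of that mapping (parseSaveMode is a hypothetical name, and the exact set of accepted aliases is an assumption, not taken from this diff):

    import org.apache.spark.sql.SaveMode

    // Approximate name-to-mode resolution; "default" and "error" are treated
    // as ErrorIfExists, matching the AnalysisException the test asserts on.
    def parseSaveMode(name: String): SaveMode = name.toLowerCase match {
      case "append"                              => SaveMode.Append
      case "overwrite"                           => SaveMode.Overwrite
      case "ignore"                              => SaveMode.Ignore
      case "error" | "errorifexists" | "default" => SaveMode.ErrorIfExists
      case other =>
        throw new IllegalArgumentException(s"Unknown save mode: $other")
    }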