@@ -386,17 +386,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

     val command = modeForDSV2 match {
       case SaveMode.Append =>
-        AppendData.byName(table, df.logicalPlan)
+        AppendData.byPosition(table, df.logicalPlan)

       case SaveMode.Overwrite =>
         val conf = df.sparkSession.sessionState.conf
         val dynamicPartitionOverwrite = table.table.partitioning.size > 0 &&
           conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

         if (dynamicPartitionOverwrite) {
-          OverwritePartitionsDynamic.byName(table, df.logicalPlan)
+          OverwritePartitionsDynamic.byPosition(table, df.logicalPlan)
         } else {
-          OverwriteByExpression.byName(table, df.logicalPlan, Literal(true))
+          OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true))
         }

       case other =>
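An aside for reviewers: a minimal sketch, not part of this diff, of what the byName-to-byPosition switch means for DataFrameWriter.insertInto. The session, table, and data below are hypothetical:

// Assumes an active SparkSession `spark`, `import spark.implicits._`,
// and an existing table `t` with schema (id bigint, data string).
// The DataFrame's column names are deliberately swapped relative to
// the table schema: the longs sit in a column named "data".
val dfr = Seq((1L, "a"), (2L, "b")).toDF("data", "id")

// Old behavior (AppendData.byName): columns are matched by name, so the
// string column named "id" would be routed to the table's `id` column.
// New behavior (AppendData.byPosition): the first DataFrame column feeds
// the first table column, mirroring SQL INSERT INTO, so the longs land
// in `id` and the strings in `data`.
dfr.write.insertInto("t")

The "by position" tests below exercise exactly this reversed-column shape.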
@@ -29,78 +29,115 @@ class DataSourceV2DataFrameSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
   before {
     spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
     spark.conf.set("spark.sql.catalog.testcat2", classOf[TestInMemoryTableCatalog].getName)
-
-    val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
-    df.createOrReplaceTempView("source")
-    val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data")
-    df2.createOrReplaceTempView("source2")
   }

   after {
     spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables()
-    spark.sql("DROP VIEW source")
-    spark.sql("DROP VIEW source2")
   }

+  test("insertInto: append") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      df.write.insertInto(t1)
+      checkAnswer(spark.table(t1), df)
+    }
+  }
+
-  test("insertInto: append") {
+  test("insertInto: append by position") {
     val t1 = "testcat.ns1.ns2.tbl"
     withTable(t1) {
       sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
-      spark.table("source").select("id", "data").write.insertInto(t1)
-      checkAnswer(spark.table(t1), spark.table("source"))
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id")
+      dfr.write.insertInto(t1)
+      checkAnswer(spark.table(t1), df)
     }
   }

-  test("insertInto: append - across catalog") {
+  test("insertInto: append across catalog") {
     val t1 = "testcat.ns1.ns2.tbl"
     val t2 = "testcat2.db.tbl"
     withTable(t1, t2) {
-      sql(s"CREATE TABLE $t1 USING foo AS TABLE source")
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
       sql(s"CREATE TABLE $t2 (id bigint, data string) USING foo")
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      df.write.insertInto(t1)
       spark.table(t1).write.insertInto(t2)
-      checkAnswer(spark.table(t2), spark.table("source"))
+      checkAnswer(spark.table(t2), df)
     }
   }

   test("insertInto: append partitioned table") {
     val t1 = "testcat.ns1.ns2.tbl"
     withTable(t1) {
       sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
-      spark.table("source").write.insertInto(t1)
-      checkAnswer(spark.table(t1), spark.table("source"))
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      df.write.insertInto(t1)
+      checkAnswer(spark.table(t1), df)
     }
   }

   test("insertInto: overwrite non-partitioned table") {
     val t1 = "testcat.ns1.ns2.tbl"
     withTable(t1) {
-      sql(s"CREATE TABLE $t1 USING foo AS TABLE source")
-      spark.table("source2").write.mode("overwrite").insertInto(t1)
-      checkAnswer(spark.table(t1), spark.table("source2"))
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      val df2 = Seq((4L, "d"), (5L, "e"), (6L, "f")).toDF("id", "data")
+      df.write.insertInto(t1)
+      df2.write.mode("overwrite").insertInto(t1)
+      checkAnswer(spark.table(t1), df2)
     }
   }

+  test("insertInto: overwrite partitioned table in static mode") {
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) {
+      val t1 = "testcat.ns1.ns2.tbl"
+      withTable(t1) {
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
+        Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1)
+        val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+        df.write.mode("overwrite").insertInto(t1)
+        checkAnswer(spark.table(t1), df)
+      }
+    }
+  }
+
-  test("insertInto: overwrite - static mode") {
+  test("insertInto: overwrite partitioned table in static mode by position") {
     withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) {
       val t1 = "testcat.ns1.ns2.tbl"
       withTable(t1) {
         sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
         Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1)
-        spark.table("source").write.mode("overwrite").insertInto(t1)
-        checkAnswer(spark.table(t1), spark.table("source"))
+        val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+        val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id")
+        dfr.write.mode("overwrite").insertInto(t1)
+        checkAnswer(spark.table(t1), df)
       }
     }
   }

+  test("insertInto: overwrite partitioned table in dynamic mode") {
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      val t1 = "testcat.ns1.ns2.tbl"
+      withTable(t1) {
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
+        Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1)
+        val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+        df.write.mode("overwrite").insertInto(t1)
+        checkAnswer(spark.table(t1), df.union(sql("SELECT 4L, 'keep'")))
+      }
+    }
+  }
+
-  test("insertInto: overwrite - dynamic mode") {
+  test("insertInto: overwrite partitioned table in dynamic mode by position") {
     withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
       val t1 = "testcat.ns1.ns2.tbl"
       withTable(t1) {
         sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)")
         Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data").write.insertInto(t1)
-        spark.table("source").write.mode("overwrite").insertInto(t1)
-        checkAnswer(spark.table(t1),
-          spark.table("source").union(sql("SELECT 4L, 'keep'")))
+        val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+        val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id")
+        dfr.write.mode("overwrite").insertInto(t1)
+        checkAnswer(spark.table(t1), df.union(sql("SELECT 4L, 'keep'")))
       }
     }
   }
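A closing aside on the static- vs dynamic-mode tests above: a hedged sketch of the configuration they toggle, reusing the shapes from the tests (the table `t1` and DataFrame `df` are hypothetical stand-ins):

// Static mode (the default): per the writer change above, the planner
// emits OverwriteByExpression.byPosition(..., Literal(true)), which
// logically truncates the table before writing, so the seeded
// (4L, "keep") row does not survive.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "static")

// Dynamic mode: OverwritePartitionsDynamic.byPosition is planned instead,
// replacing only partitions present in the incoming data; partition id=4
// is untouched, which is why the dynamic tests expect
// df.union(sql("SELECT 4L, 'keep'")).
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
df.write.mode("overwrite").insertInto(t1)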