diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index a4e2dba53438..89618668497c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
+class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME",
"mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04")
@@ -150,8 +150,18 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
"""
|INSERT INTO bits VALUES (1, 2, 1)
""".stripMargin).executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE upsert (id INT, ts DATETIME, v1 FLOAT, v2 FLOAT, " +
+ "CONSTRAINT pk_upsert PRIMARY KEY (id, ts))").executeUpdate()
+ conn.prepareStatement("INSERT INTO upsert VALUES " +
+ "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
+ "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
+ "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
+ "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}
+ override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"
+
test("Basic test") {
val df = spark.read.jdbc(jdbcUrl, "tbl", new Properties)
val rows = df.collect()
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index dc3acb66ff1f..d66bba64a3de 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
+class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31")
override val env = Map(
@@ -43,7 +43,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
override val usesIpc = false
override val jdbcPort: Int = 3306
override def getJdbcUrl(ip: String, port: Int): String =
- s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass"
+ s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass&allowMultiQueries=true"
}
override def dataPreparation(conn: Connection): Unit = {
@@ -72,8 +72,18 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
).executeUpdate()
conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
"'jumps', 'over', 'the', 'lazy', 'dog', '{\"status\": \"merrily\"}')").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE upsert (id INTEGER, ts TIMESTAMP, v1 DOUBLE, v2 DOUBLE, " +
+ "PRIMARY KEY pk (id, ts))").executeUpdate()
+ conn.prepareStatement("INSERT INTO upsert VALUES " +
+ "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
+ "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
+ "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
+ "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}
+ override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"
+
test("Basic test") {
val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties)
val rows = df.collect()
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index c539452bb9ae..f1ff97beddb2 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -23,8 +23,7 @@ import java.text.SimpleDateFormat
import java.time.LocalDateTime
import java.util.Properties
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType}
import org.apache.spark.tags.DockerTest
@@ -38,7 +37,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
-class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
+class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine")
override val env = Map(
@@ -154,8 +153,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
conn.prepareStatement("INSERT INTO custom_type (type_array, type) VALUES" +
"('{1,fds,fdsa}','fdasfasdf')").executeUpdate()
+ conn.prepareStatement("CREATE TABLE upsert (id integer, ts timestamp, v1 double precision, " +
+ "v2 double precision, CONSTRAINT pk PRIMARY KEY (id, ts))").executeUpdate()
+ conn.prepareStatement("INSERT INTO upsert VALUES " +
+ "(1, '1996-01-01 01:23:45', 1.234, 1.234567), " +
+ "(1, '1996-01-01 01:23:46', 1.235, 1.234568), " +
+ "(2, '1996-01-01 01:23:45', 2.345, 2.345678), " +
+ "(2, '1996-01-01 01:23:46', 2.346, 2.345679)").executeUpdate()
}
+ override val createTableOption = "; ALTER TABLE new_upsert_table ADD PRIMARY KEY (id, ts)"
+
test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect().sortBy(_.toString())
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala
new file mode 100644
index 000000000000..7230b1a7001c
--- /dev/null
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/UpsertTests.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Timestamp
+import java.util.Properties
+
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.functions.{lit, rand, when}
+
+trait UpsertTests {
+ self: DockerJDBCIntegrationSuite =>
+
+ import testImplicits._
+
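+  /**
+   * Dialect-specific SQL appended to Spark's generated CREATE TABLE statement (through the
+   * createTableOptions JDBC option) so that tables created by these tests get a primary key
+   * on (id, ts).
+   */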
+ def createTableOption: String
+ def upsertTestOptions: Map[String, String] = Map("createTableOptions" -> createTableOption)
+
+ test(s"Upsert existing table") { doTestUpsert(true) }
+ test(s"Upsert non-existing table") { doTestUpsert(false) }
+
+ def doTestUpsert(tableExists: Boolean): Unit = {
+ val df = Seq(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568), // row unchanged
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678), // updates v1
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680), // updates v1 and v2
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789) // inserts new row
+ ).toDF("id", "ts", "v1", "v2").repartition(1) // .repartition(10)
+
+ val table = if (tableExists) "upsert" else "new_upsert_table"
+ val options = upsertTestOptions ++ Map(
+ "numPartitions" -> "10",
+ "upsert" -> "true",
+ "upsertKeyColumns" -> "id, ts"
+ )
+ df.write.mode(SaveMode.Append).options(options).jdbc(jdbcUrl, table, new Properties)
+
+ val actual = spark.read.jdbc(jdbcUrl, table, new Properties).collect.toSet
+ val existing = if (tableExists) {
+ Set((1, Timestamp.valueOf("1996-01-01 01:23:45"), 1.234, 1.234567))
+ } else {
+ Set.empty
+ }
+ val upsertedRows = Set(
+ (1, Timestamp.valueOf("1996-01-01 01:23:46"), 1.235, 1.234568),
+ (2, Timestamp.valueOf("1996-01-01 01:23:45"), 2.346, 2.345678),
+ (2, Timestamp.valueOf("1996-01-01 01:23:46"), 2.347, 2.345680),
+ (3, Timestamp.valueOf("1996-01-01 01:23:45"), 3.456, 3.456789)
+ )
+ val expected = (existing ++ upsertedRows).map { case (id, ts, v1, v2) =>
+ Row(Integer.valueOf(id), ts, v1.doubleValue(), v2.doubleValue())
+ }
+ assert(actual === expected)
+ }
+
+ test(s"Upsert concurrency") {
+ // create a table with 100k rows
+ val init =
+ spark.range(100000)
+ .withColumn("ts", lit(Timestamp.valueOf("2023-06-07 12:34:56")))
+ .withColumn("v", rand())
+
+ // upsert 100 batches of 100 rows each
+ // run in 32 tasks
+ val patch =
+ spark.range(100)
+ .join(spark.range(100).select(($"id" * 1000).as("offset")))
+ .repartition(32)
+ .select(
+ ($"id" + $"offset").as("id"),
+ lit(Timestamp.valueOf("2023-06-07 12:34:56")).as("ts"),
+ lit(-1.0).as("v")
+ )
+
+ spark.sparkContext.setJobDescription("init")
+ init
+ .write
+ .mode(SaveMode.Overwrite)
+ .option("createTableOptions", createTableOption)
+ .jdbc(jdbcUrl, "new_upsert_table", new Properties)
+
+ spark.sparkContext.setJobDescription("patch")
+ patch
+ .write
+ .mode(SaveMode.Append)
+ .option("upsert", true)
+ .option("upsertKeyColumns", "id, ts")
+ .options(upsertTestOptions)
+ .jdbc(jdbcUrl, "new_upsert_table", new Properties)
+
+    // check that the result table has 100*100 = 10,000 patched rows and 90,000 untouched rows
+ val result = spark.read.jdbc(jdbcUrl, "new_upsert_table", new Properties)
+ .select($"id", when($"v" === -1.0, true).otherwise(false).as("patched"))
+ .groupBy($"patched")
+ .count()
+ .sort($"patched")
+ .as[(Boolean, Long)]
+ .collect()
+ assert(result === Seq((false, 90000), (true, 10000)))
+ }
+
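+  // TODO: the tests below are placeholders that still need to be implemented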
+ test("Upsert null values") {}
+ test("Write with unspecified mode with upsert") {}
+ test("Write with overwrite mode with upsert") {}
+ test("Write with error-if-exists mode with upsert") {}
+ test("Write with ignore mode with upsert") {}
+}
diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index f96776514c67..4cae6b23759a 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -261,6 +261,19 @@ logging into the data sources.
     <td>write</td>
   </tr>
+
+  <tr>
+    <td><code>upsert</code>, <code>upsertKeyColumns</code></td>
+    <td><code>false</code>, (none)</td>
+    <td>
+      These are JDBC writer related options. They describe how to use the UPSERT feature of
+      the individual JDBC dialects. The <code>upsert</code> option only takes effect with
+      <code>SaveMode.Append</code>; set it to <code>true</code> to enable upsert append mode.
+      The upsert key columns, which identify the rows to update, are detected from the
+      database's primary index, or can be defined explicitly via <code>upsertKeyColumns</code>
+      as a comma-separated list of column names. Be aware that the upsert operation is
+      non-deterministic if the input data set contains duplicate rows; see the
+      <a href="https://en.wikipedia.org/wiki/Merge_(SQL)">Merge (SQL)</a> Wikipedia article.
+    </td>
+    <td>write</td>
+  </tr>
   <tr>
     <td><code>customSchema</code></td>
     <td>(none)</td>
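
For illustration, a minimal sketch of how the documented options are used through the DataFrame
writer API. The JDBC URL, table name and key columns are placeholders, and df stands for any
DataFrame whose columns match the target table:

    import java.util.Properties
    import org.apache.spark.sql.SaveMode

    // Append with upsert semantics: rows whose key columns match an existing row are updated,
    // all other rows are inserted. Only SaveMode.Append honours the upsert option.
    df.write
      .mode(SaveMode.Append)
      .option("upsert", "true")
      .option("upsertKeyColumns", "id, ts")
      .jdbc("jdbc:postgresql://host:5432/db", "target_table", new Properties)
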
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 94b8ee25dd2a..60de618be67a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1344,6 +1344,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
tableDoesNotSupportError("truncates", table)
}
+ def tableDoesNotSupportUpsertsError(table: String): Throwable = {
+ new AnalysisException(
+ errorClass = "_LEGACY_ERROR_TEMP_1121",
+ messageParameters = Map(
+ "cmd" -> "upserts",
+ "table" -> table))
+ }
+
def tableDoesNotSupportPartitionManagementError(table: Table): Throwable = {
tableDoesNotSupportError("partition management", table)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 268a65b81ff6..6b02bf38cdce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -158,8 +158,12 @@ class JDBCOptions(
// ------------------------------------------------------------
// if to truncate the table from the JDBC database
val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean
-
val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean)
+ // if to upsert the table in the JDBC database
+ val isUpsert = parameters.getOrElse(JDBC_UPSERT, "false").toBoolean
+ // the columns used to identify update and insert rows in upsert mode
+  val upsertKeyColumns =
+    parameters.getOrElse(JDBC_UPSERT_KEY_COLUMNS, "").split(",").map(_.trim).filter(_.nonEmpty)
+
// the create table option , which can be table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
@@ -284,6 +288,8 @@ object JDBCOptions {
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate")
+ val JDBC_UPSERT = newOption("upsert")
+ val JDBC_UPSERT_KEY_COLUMNS = newOption("upsertKeyColumns")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 2760c7ac3019..43b56744d9a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -49,6 +49,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
val dialect = JdbcDialects.get(options.url)
val conn = dialect.createConnectionFactory(options)(-1)
try {
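+      // Upsert only applies to SaveMode.Append; other save modes keep their existing behavior.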
+ val upsert = mode == SaveMode.Append && options.isUpsert
val tableExists = JdbcUtils.tableExists(conn, options)
if (tableExists) {
mode match {
@@ -57,17 +58,20 @@ class JdbcRelationProvider extends CreatableRelationProvider
// In this case, we should truncate table and then load.
truncateTable(conn, options)
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
- saveTable(df, tableSchema, isCaseSensitive, options)
+ saveTable(df, tableSchema, isCaseSensitive, upsert, options)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, options.table, options)
createTable(conn, options.table, df.schema, isCaseSensitive, options)
- saveTable(df, Some(df.schema), isCaseSensitive, options)
+ saveTable(df, Some(df.schema), isCaseSensitive, upsert, options)
}
case SaveMode.Append =>
+            if (options.isUpsert && !dialect.supportsUpsert()) {
+ throw QueryCompilationErrors.tableDoesNotSupportUpsertsError(options.table)
+ }
val tableSchema = JdbcUtils.getSchemaOption(conn, options)
- saveTable(df, tableSchema, isCaseSensitive, options)
+ saveTable(df, tableSchema, isCaseSensitive, upsert, options)
case SaveMode.ErrorIfExists =>
throw QueryCompilationErrors.tableOrViewAlreadyExistsError(options.table)
@@ -79,7 +83,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
}
} else {
createTable(conn, options.table, df.schema, isCaseSensitive, options)
- saveTable(df, Some(df.schema), isCaseSensitive, options)
+ saveTable(df, Some(df.schema), isCaseSensitive, upsert, options)
}
} finally {
conn.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index d907ce6b100c..b54a7edf205d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -104,17 +104,12 @@ object JdbcUtils extends Logging with SQLConfHelper {
JdbcDialects.get(url).isCascadingTruncateTable()
}
- /**
- * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
- */
- def getInsertStatement(
- table: String,
+  private def getInsertColumns(
rddSchema: StructType,
tableSchema: Option[StructType],
- isCaseSensitive: Boolean,
- dialect: JdbcDialect): String = {
- val columns = if (tableSchema.isEmpty) {
- rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+ dialect: JdbcDialect): Array[String] = {
+ if (tableSchema.isEmpty) {
+ rddSchema.fields.map(x => dialect.quoteIdentifier(x.name))
} else {
// The generated insert statement needs to follow rddSchema's column sequence and
// tableSchema's column names. When appending data into some case-sensitive DBMSs like
@@ -126,10 +121,33 @@ object JdbcUtils extends Logging with SQLConfHelper {
throw QueryCompilationErrors.columnNotFoundInSchemaError(col, tableSchema)
}
dialect.quoteIdentifier(normalizedName)
- }.mkString(",")
+ }
}
- val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
- s"INSERT INTO $table ($columns) VALUES ($placeholders)"
+ }
+
+ /**
+ * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
+ */
+ def getInsertStatement(
+ table: String,
+ rddSchema: StructType,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ dialect: JdbcDialect): String = {
+ val columns = getInsertColumns(rddSchema, tableSchema, dialect)
+ val placeholders = columns.map(_ => "?")
+ s"INSERT INTO $table (${columns.mkString(",")}) VALUES (${placeholders.mkString(",")})"
+ }
+
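+  /**
+   * Returns the dialect-specific upsert SQL statement for inserting a row into the target
+   * table, or updating the existing row that has the same values in the upsert key columns.
+   */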
+ def getUpsertStatement(
+ table: String,
+ rddSchema: StructType,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ dialect: JdbcDialect,
+ options: JDBCOptions): String = {
+ val columns = getInsertColumns(rddSchema, tableSchema, dialect)
+    dialect.getUpsertStatement(
+      table, columns, rddSchema.fields.map(_.dataType), isCaseSensitive, options)
}
/**
@@ -878,6 +896,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
df: DataFrame,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
+ upsert: Boolean,
options: JdbcOptionsInWrite): Unit = {
val url = options.url
val table = options.table
@@ -886,7 +905,12 @@ object JdbcUtils extends Logging with SQLConfHelper {
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
- val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+ val insertStmt = if (upsert) {
+ getUpsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, options)
+ } else {
+ getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
+ }
+
val repartitionedDF = options.numPartitions match {
case Some(n) if n <= 0 => throw QueryExecutionErrors.invalidJdbcNumPartitionsError(
n, JDBCOptions.JDBC_NUM_PARTITIONS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
index 7449f66ee020..1b0bf0f78815 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
@@ -42,7 +42,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext
val conn = dialect.createConnectionFactory(options)(-1)
JdbcUtils.truncateTable(conn, options)
}
- JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options)
+ JdbcUtils.saveTable(
+ data, Some(schema), SQLConf.get.caseSensitiveAnalysis, upsert = false, options)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 93a311be2f86..1f85a3f12c6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -230,6 +230,18 @@ abstract class JdbcDialect extends Serializable with Logging {
s"TRUNCATE TABLE $table"
}
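+  /**
+   * Returns true if the dialect can generate an upsert statement via `getUpsertStatement`,
+   * i.e. a statement that inserts a row or updates the existing row with the same values
+   * in the upsert key columns.
+   */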
+ @Since("3.5.0")
+ def supportsUpsert(): Boolean = false
+
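+  /**
+   * Returns an upsert SQL statement for this dialect. The statement inserts a row into the
+   * table or, if a row with the same values in the configured upsert key columns already
+   * exists, updates the remaining columns of that row. The given column names are already
+   * quoted and match the order of `types`.
+   *
+   * Dialects that return true from `supportsUpsert()` must override this method.
+   */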
+ @Since("3.5.0")
+ def getUpsertStatement(
+ tableName: String,
+ columns: Array[String],
+ types: Array[DataType],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String =
+ throw new UnsupportedOperationException("upserts are not supported")
+
/**
* Override connection specific properties to run before a select is made. This is in place to
* allow dialects that need special treatment to optimize behavior.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 78ec3ac42d79..4c39641b3401 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -129,6 +129,55 @@ private object MsSqlServerDialect extends JdbcDialect {
case _ => None
}
+ override def supportsUpsert(): Boolean = true
+
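+  /**
+   * Generates a T-SQL batch that declares one variable per bind parameter, attempts an
+   * INSERT guarded by NOT EXISTS (under UPDLOCK, SERIALIZABLE locks to avoid races with
+   * concurrent writers), and falls back to an UPDATE of the existing row when the INSERT
+   * affected no rows.
+   */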
+ override def getUpsertStatement(
+ tableName: String,
+ columns: Array[String],
+ types: Array[DataType],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String = {
+ val insertColumns = columns.mkString(", ")
+ val inputs = types
+ .map(t => JdbcUtils.getJdbcType(t, this).databaseTypeDefinition)
+ .zipWithIndex.map {
+ case (t, idx) => s"DECLARE @param$idx $t; SET @param$idx = ?;"
+ }.mkString("\n")
+ val values = columns.indices.map(i => s"@param$i").mkString(", ")
+ val quotedUpsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier)
+ val keyColumns = columns.zipWithIndex.filter {
+ case (col, _) => quotedUpsertKeyColumns.contains(col)
+ }
+ val updateColumns = columns.zipWithIndex.filterNot {
+ case (col, _) => quotedUpsertKeyColumns.contains(col)
+ }
+ val whereClause = keyColumns.map {
+ case (key, idx) => s"$key = @param$idx"
+ }.mkString(" AND ")
+ val updateClause = updateColumns.map {
+ case (col, idx) => s"$col = @param$idx"
+ }.mkString(", ")
+
+ s"""
+ |$inputs
+ |
+ |INSERT $tableName ($insertColumns)
+ |SELECT $values
+ |WHERE NOT EXISTS (
+ | SELECT 1
+ | FROM $tableName WITH (UPDLOCK, SERIALIZABLE)
+ | WHERE $whereClause
+ |)
+ |
+ |IF (@@ROWCOUNT = 0)
+ |BEGIN
+ | UPDATE TOP (1) $tableName
+ | SET $updateClause
+ | WHERE $whereClause
+ |END
+ |""".stripMargin
+ }
+
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
// scalastyle:off line.size.limit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index d6edb67e57e4..93f3cb78487d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -132,6 +132,28 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
s"SELECT 1 FROM $table LIMIT 1"
}
+ override def supportsUpsert(): Boolean = true
+
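+  /**
+   * MySQL upserts via INSERT ... ON DUPLICATE KEY UPDATE. The server detects conflicts
+   * against the table's primary key or unique indexes, so the configured upsert key columns
+   * are only used to exclude the key columns from the generated update clause.
+   */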
+ override def getUpsertStatement(
+ tableName: String,
+ columns: Array[String],
+ types: Array[DataType],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String = {
+ val insertColumns = columns.mkString(", ")
+ val placeholders = columns.map(_ => "?").mkString(",")
+ val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier)
+ val updateColumns = columns.filterNot(upsertKeyColumns.contains)
+ val updateClause =
+ updateColumns.map(x => s"$x = VALUES($x)").mkString(", ")
+
+ s"""
+ |INSERT INTO $tableName ($insertColumns)
+ |VALUES ( $placeholders )
+ |ON DUPLICATE KEY UPDATE $updateClause
+ |""".stripMargin
+ }
+
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
// See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html
@@ -262,7 +284,7 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
}
} catch {
case _: Exception =>
- logWarning("Cannot retrieved index info.")
+ logWarning("Cannot retrieve index info.")
}
indexMap.values.toArray
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index ab8b1a7e1a50..bdbe68626990 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -133,6 +133,29 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {
s"SELECT 1 FROM $table LIMIT 1"
}
+ override def supportsUpsert(): Boolean = true
+
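+  /**
+   * PostgreSQL upserts via INSERT ... ON CONFLICT (...) DO UPDATE. The conflict target is the
+   * configured upsert key columns, which must be covered by a unique index or constraint on
+   * the target table.
+   */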
+ override def getUpsertStatement(
+ tableName: String,
+ columns: Array[String],
+ types: Array[DataType],
+ isCaseSensitive: Boolean,
+ options: JDBCOptions): String = {
+ val insertColumns = columns.mkString(", ")
+ val placeholders = columns.map(_ => "?").mkString(",")
+ val upsertKeyColumns = options.upsertKeyColumns.map(quoteIdentifier)
+ val updateColumns = columns.filterNot(upsertKeyColumns.contains)
+ val updateClause =
+ updateColumns.map(x => s"$x = EXCLUDED.$x").mkString(", ")
+
+ s"""
+ |INSERT INTO $tableName ($insertColumns)
+ |VALUES ( $placeholders )
+ |ON CONFLICT (${upsertKeyColumns.mkString(", ")})
+ |DO UPDATE SET $updateClause
+ |""".stripMargin
+ }
+
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 93b6652d516c..9e258a19cb3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode, Project
import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions.JDBC_UPSERT_KEY_COLUMNS
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
@@ -1113,6 +1114,65 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
assert(db2.getTruncateQuery(table, Some(true)) == db2Query)
}
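+  // Expected upsert statements per dialect for a table with upsert key columns (id, time)
+  // and value columns (value, comment).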
+ Seq(
+ (JdbcDialects.get("jdbc:mysql://127.0.0.1/db"),
+ """
+ |INSERT INTO table (`id`, `time`, `value`, `comment`)
+ |VALUES ( ?,?,?,? )
+ |ON DUPLICATE KEY UPDATE `value` = VALUES(`value`), `comment` = VALUES(`comment`)
+ |""".stripMargin),
+ (JdbcDialects.get("jdbc:postgresql://127.0.0.1/db"),
+ """
+ |INSERT INTO table ("id", "time", "value", "comment")
+ |VALUES ( ?,?,?,? )
+ |ON CONFLICT ("id", "time")
+ |DO UPDATE SET "value" = EXCLUDED."value", "comment" = EXCLUDED."comment"
+ |""".stripMargin),
+ (JdbcDialects.get("jdbc:sqlserver://localhost/db"),
+ """
+ |DECLARE @param0 BIGINT; SET @param0 = ?;
+ |DECLARE @param1 DATETIME; SET @param1 = ?;
+ |DECLARE @param2 DOUBLE PRECISION; SET @param2 = ?;
+ |DECLARE @param3 NVARCHAR(MAX); SET @param3 = ?;
+ |
+ |INSERT table ("id", "time", "value", "comment")
+ |SELECT @param0, @param1, @param2, @param3
+ |WHERE NOT EXISTS (
+ | SELECT 1
+ | FROM table WITH (UPDLOCK, SERIALIZABLE)
+ | WHERE "id" = @param0 AND "time" = @param1
+ |)
+ |
+ |IF (@@ROWCOUNT = 0)
+ |BEGIN
+ | UPDATE TOP (1) table
+ | SET "value" = @param2, "comment" = @param3
+ | WHERE "id" = @param0 AND "time" = @param1
+ |END
+ |""".stripMargin)
+ ).foreach { case (dialect, expected) =>
+ test(s"upsert table query by dialect - ${dialect.getClass.getSimpleName.stripSuffix("$")}") {
+ assert(dialect.supportsUpsert() === true)
+
+      val options = new JDBCOptions(Map(
+        JDBC_UPSERT_KEY_COLUMNS -> "id, time",
+        JDBCOptions.JDBC_URL -> url,
+        JDBCOptions.JDBC_TABLE_NAME -> "table"))
+
+ val table = "table"
+ val columns = Array("id", "time", "value", "comment")
+ val quotedColumns = columns.map(dialect.quoteIdentifier)
+ val types: Array[DataType] = Array(LongType, TimestampType, DoubleType, StringType)
+ val isCaseSensitive = false
+ val stmt = dialect.getUpsertStatement(table, quotedColumns, types, isCaseSensitive, options)
+
+ assert(stmt === expected)
+ }
+ }
+
test("Test DataFrame.where for Date and Timestamp") {
// Regression test for bug SPARK-11788
val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");