diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 86a27b5afc25..55ca4e3624bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -605,6 +605,13 @@ object JdbcUtils extends Logging {
    * implementation changes elsewhere might easily render such a closure
    * non-Serializable. Instead, we explicitly close over all variables that
    * are used.
+   *
+   * Note that this method records task output metrics. It assumes it is being run
+   * inside a task. For now, we only record the number of rows written, because there
+   * is no good way to measure the total bytes written. Only effective output is taken
+   * into account: for example, the metric is not updated when transactions are
+   * supported and the transaction is rolled back, but it is updated even on error
+   * when transactions are not supported, because dirty output has already been written.
    */
   def savePartition(
       getConnection: () => Connection,
@@ -615,7 +622,9 @@ object JdbcUtils extends Logging {
       batchSize: Int,
       dialect: JdbcDialect,
       isolationLevel: Int,
-      options: JDBCOptions): Iterator[Byte] = {
+      options: JDBCOptions): Unit = {
+    val outMetrics = TaskContext.get().taskMetrics().outputMetrics
+
     val conn = getConnection()
     var committed = false
 
@@ -643,7 +652,7 @@ object JdbcUtils extends Logging {
       }
     }
     val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE
-
+    var totalRowCount = 0
     try {
       if (supportsTransactions) {
         conn.setAutoCommit(false) // Everything in the same db transaction.
@@ -672,6 +681,7 @@ object JdbcUtils extends Logging {
         }
         stmt.addBatch()
         rowCount += 1
+        totalRowCount += 1
         if (rowCount % batchSize == 0) {
           stmt.executeBatch()
           rowCount = 0
@@ -687,7 +697,6 @@ object JdbcUtils extends Logging {
         conn.commit()
       }
       committed = true
-      Iterator.empty
     } catch {
       case e: SQLException =>
         val cause = e.getNextException
@@ -715,9 +724,13 @@ object JdbcUtils extends Logging {
         // tell the user about another problem.
         if (supportsTransactions) {
           conn.rollback()
+        } else {
+          outMetrics.setRecordsWritten(totalRowCount)
         }
         conn.close()
       } else {
+        outMetrics.setRecordsWritten(totalRowCount)
+
         // The stage must succeed. We cannot propagate any exception close() might throw.
         try {
           conn.close()
@@ -840,10 +853,10 @@ object JdbcUtils extends Logging {
       case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
       case _ => df
     }
-    repartitionedDF.rdd.foreachPartition(iterator => savePartition(
+    repartitionedDF.rdd.foreachPartition { iterator => savePartition(
       getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
       options)
-    )
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index b28c6531d42b..8021ef1a17a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -21,10 +21,12 @@ import java.sql.DriverManager
 import java.util.Properties
 
 import scala.collection.JavaConverters.propertiesAsScalaMapConverter
+import scala.collection.mutable.ArrayBuffer
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkException
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
@@ -543,4 +545,57 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
     }.getMessage
     assert(errMsg.contains("Statement was canceled or the session timed out"))
   }
+
+  test("metrics") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+
+    runAndVerifyRecordsWritten(2) {
+      df.write.mode(SaveMode.Append).jdbc(url, "TEST.BASICCREATETEST", new Properties())
+    }
+
+    runAndVerifyRecordsWritten(1) {
+      df2.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", new Properties())
+    }
+
+    runAndVerifyRecordsWritten(1) {
+      df2.write.mode(SaveMode.Overwrite).option("truncate", true)
+        .jdbc(url, "TEST.BASICCREATETEST", new Properties())
+    }
+
+    runAndVerifyRecordsWritten(0) {
+      intercept[AnalysisException] {
+        df2.write.mode(SaveMode.ErrorIfExists).jdbc(url, "TEST.BASICCREATETEST", new Properties())
+      }
+    }
+
+    runAndVerifyRecordsWritten(0) {
+      df.write.mode(SaveMode.Ignore).jdbc(url, "TEST.BASICCREATETEST", new Properties())
+    }
+  }
+
+  private def runAndVerifyRecordsWritten(expected: Long)(job: => Unit): Unit = {
+    assert(expected === runAndReturnMetrics(job, _.taskMetrics.outputMetrics.recordsWritten))
+  }
+
+  private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Long): Long = {
+    val taskMetrics = new ArrayBuffer[Long]()
+
+    // Avoid receiving earlier taskEnd events
+    sparkContext.listenerBus.waitUntilEmpty()
+
+    val listener = new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        taskMetrics += collector(taskEnd)
+      }
+    }
+    sparkContext.addSparkListener(listener)
+
+    job
+
+    sparkContext.listenerBus.waitUntilEmpty()
+
+    sparkContext.removeSparkListener(listener)
+    taskMetrics.sum
+  }
 }
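
The following is a minimal sketch (not part of the patch) of how an application could observe the recordsWritten metric that this change populates for JDBC writes. It mirrors the listener-based approach used in the test suite above, but the local SparkSession, the H2 in-memory URL, the table name, and the example object itself are assumptions for illustration only; the H2 driver is assumed to be on the classpath.

// Sketch only: reads outputMetrics.recordsWritten after a DataFrame JDBC write.
// The H2 URL and table name below are hypothetical; any JDBC-accessible database works.
import java.util.Properties
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcWriteMetricsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("jdbc-write-metrics")
      .getOrCreate()
    import spark.implicits._

    // Sum recordsWritten across all completed tasks.
    val recordsWritten = new AtomicLong(0L)
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        // taskMetrics can be null for some failed tasks, hence the Option guard.
        Option(taskEnd.taskMetrics).foreach { m =>
          recordsWritten.addAndGet(m.outputMetrics.recordsWritten)
        }
      }
    })

    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")
    val url = "jdbc:h2:mem:metricsdb;DB_CLOSE_DELAY=-1" // hypothetical in-memory database
    df.write.mode(SaveMode.Append).jdbc(url, "METRICS_TEST", new Properties())

    // Listener events are delivered asynchronously; the test above waits with the
    // internal listenerBus.waitUntilEmpty(), which is not public API, so a short
    // sleep is used here instead.
    Thread.sleep(2000)
    println(s"records written via JDBC: ${recordsWritten.get()}") // expected: 3

    spark.stop()
  }
}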