@@ -605,6 +605,13 @@ object JdbcUtils extends Logging {
* implementation changes elsewhere might easily render such a closure
* non-Serializable. Instead, we explicitly close over all variables that
* are used.
*
* Note that this method records task output metrics. It assumes it is being run
* inside a task. For now, we only record the number of rows written, because there
* is no good way to measure the total bytes written. Only effective output is taken
* into account: for example, the metric is not updated when the connection supports
* transactions and the transaction is rolled back, but it is updated even on error
* when the connection does not support transactions, since there are dirty outputs.
*/
def savePartition(
getConnection: () => Connection,
@@ -615,7 +622,9 @@
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int,
options: JDBCOptions): Iterator[Byte] = {
options: JDBCOptions): Unit = {
val outMetrics = TaskContext.get().taskMetrics().outputMetrics

val conn = getConnection()
var committed = false

@@ -643,7 +652,7 @@ object JdbcUtils extends Logging {
}
}
val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

var totalRowCount = 0
try {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
@@ -672,6 +681,7 @@ object JdbcUtils extends Logging {
}
stmt.addBatch()
rowCount += 1
Member

Can't we move rowCount outside the try block and then just use it for the metric?

Contributor Author
@HeartSaVioR HeartSaVioR Oct 24, 2019

It's used to determine whether one more flush is needed at the end of iteration. It could just be a boolean flag, but we'd still need a dedicated variable to track that anyway.
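
For readers following along, here is a minimal sketch of that batching loop (a simplification, not the code in this PR; `writeRows` and `setRowValues` are hypothetical stand-ins for the real statement handling and per-column setters) showing why the per-batch counter has to stay: its remainder decides whether one last executeBatch is needed.

```scala
import java.sql.PreparedStatement
import org.apache.spark.sql.Row

// Simplified sketch of the batching loop; `setRowValues` is a hypothetical stand-in
// for the per-column setters used by the real savePartition.
def writeRows(
    iterator: Iterator[Row],
    stmt: PreparedStatement,
    batchSize: Int)(setRowValues: (PreparedStatement, Row) => Unit): Long = {
  var rowCount = 0        // rows buffered in the current, not-yet-flushed batch
  var totalRowCount = 0L  // every row handed to JDBC; this is what the metric reports
  while (iterator.hasNext) {
    setRowValues(stmt, iterator.next())
    stmt.addBatch()
    rowCount += 1
    totalRowCount += 1
    if (rowCount % batchSize == 0) {
      stmt.executeBatch()
      rowCount = 0        // reset per batch; totalRowCount keeps growing
    }
  }
  if (rowCount > 0) {     // leftover rows that never filled a complete batch
    stmt.executeBatch()
  }
  totalRowCount
}
```

The final `if` is the extra flush being discussed: it only fires when the last batch was partially filled.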

Member

Ur, I see.

Member

Can you leave some comments somewhere about the policy to collect metrics?

Contributor Author

Just added it. 6e908d1

Member

Thanks!
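
To make the documented policy concrete, here is a rough sketch (not the exact code in this PR; `writeRow` is a hypothetical per-row writer, and the `setRecordsWritten` call mirrors the one added to JdbcUtils, which only compiles inside Spark's own packages) of when the metric is published.

```scala
import java.sql.Connection
import org.apache.spark.TaskContext

// Rough sketch of the metric policy: rows are counted only when they are actually
// visible to the database.
def writeWithMetrics[T](
    conn: Connection,
    rows: Iterator[T],
    supportsTransactions: Boolean)(writeRow: T => Unit): Unit = {
  val outMetrics = TaskContext.get().taskMetrics().outputMetrics
  var committed = false
  var totalRowCount = 0L
  try {
    rows.foreach { row =>
      writeRow(row)      // hypothetical per-row write against `conn`
      totalRowCount += 1
    }
    if (supportsTransactions) {
      conn.commit()
    }
    committed = true
  } finally {
    if (!committed && supportsTransactions) {
      conn.rollback()    // rolled back, so nothing persisted: leave the metric untouched
    } else {
      // Either the write succeeded, or there is no transaction support and the rows
      // written before the failure are already visible ("dirty outputs"): count them.
      outMetrics.setRecordsWritten(totalRowCount)
    }
    conn.close()
  }
}
```

The new test in JDBCWriteSuite below checks the resulting counts per SaveMode through a SparkListener.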

totalRowCount += 1
if (rowCount % batchSize == 0) {
stmt.executeBatch()
rowCount = 0
@@ -687,7 +697,6 @@
conn.commit()
}
committed = true
Iterator.empty
} catch {
case e: SQLException =>
val cause = e.getNextException
@@ -715,9 +724,13 @@ object JdbcUtils extends Logging {
// tell the user about another problem.
if (supportsTransactions) {
conn.rollback()
} else {
outMetrics.setRecordsWritten(totalRowCount)
}
conn.close()
} else {
outMetrics.setRecordsWritten(totalRowCount)

// The stage must succeed. We cannot propagate any exception close() might throw.
try {
conn.close()
@@ -840,10 +853,10 @@ object JdbcUtils extends Logging {
case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
case _ => df
}
repartitionedDF.rdd.foreachPartition(iterator => savePartition(
repartitionedDF.rdd.foreachPartition { iterator => savePartition(
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
options)
)
}
}

/**
@@ -21,10 +21,12 @@ import java.sql.DriverManager
import java.util.Properties

import scala.collection.JavaConverters.propertiesAsScalaMapConverter
import scala.collection.mutable.ArrayBuffer

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
@@ -543,4 +545,57 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
}.getMessage
assert(errMsg.contains("Statement was canceled or the session timed out"))
}

test("metrics") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)

runAndVerifyRecordsWritten(2) {
df.write.mode(SaveMode.Append).jdbc(url, "TEST.BASICCREATETEST", new Properties())
}

runAndVerifyRecordsWritten(1) {
df2.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", new Properties())
}

runAndVerifyRecordsWritten(1) {
df2.write.mode(SaveMode.Overwrite).option("truncate", true)
.jdbc(url, "TEST.BASICCREATETEST", new Properties())
}

runAndVerifyRecordsWritten(0) {
intercept[AnalysisException] {
df2.write.mode(SaveMode.ErrorIfExists).jdbc(url, "TEST.BASICCREATETEST", new Properties())
}
}

runAndVerifyRecordsWritten(0) {
df.write.mode(SaveMode.Ignore).jdbc(url, "TEST.BASICCREATETEST", new Properties())
}
}

private def runAndVerifyRecordsWritten(expected: Long)(job: => Unit): Unit = {
assert(expected === runAndReturnMetrics(job, _.taskMetrics.outputMetrics.recordsWritten))
}

private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Long): Long = {
Contributor Author

This is copied from InputOutputMetricsSuite - please let me know if it should be extracted into some utility class/object.

val taskMetrics = new ArrayBuffer[Long]()

// Avoid receiving earlier taskEnd events
sparkContext.listenerBus.waitUntilEmpty()

val listener = new SparkListener() {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
taskMetrics += collector(taskEnd)
}
}
sparkContext.addSparkListener(listener)

job

sparkContext.listenerBus.waitUntilEmpty()

sparkContext.removeSparkListener(listener)
taskMetrics.sum
}
}