-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-16675][SQL] Avoid per-record type dispatch in JDBC when writing #14323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4284d46
8cac7de
532b3b1
81d8aca
c33bb62
f2be8a4
ab4a1cf
fb0f9a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -154,6 +154,79 @@ object JdbcUtils extends Logging { | |
| throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) | ||
| } | ||
|
|
||
| // A `JDBCValueSetter` sets one value from a `Row` into the matching parameter of a | ||
| // `PreparedStatement`. The last argument `Int` is the 0-based index of the value in the | ||
| // `Row`; the corresponding SQL statement parameter is at that index plus one. | ||
| private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit | ||
|
|
||
| private def makeSetter( | ||
| conn: Connection, | ||
| dialect: JdbcDialect, | ||
| dataType: DataType): JDBCValueSetter = dataType match { | ||
| case IntegerType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setInt(pos + 1, row.getInt(pos)) | ||
|
|
||
| case LongType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setLong(pos + 1, row.getLong(pos)) | ||
|
|
||
| case DoubleType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setDouble(pos + 1, row.getDouble(pos)) | ||
|
|
||
| case FloatType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setFloat(pos + 1, row.getFloat(pos)) | ||
|
|
||
| case ShortType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setInt(pos + 1, row.getShort(pos)) | ||
|
|
||
| case ByteType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setInt(pos + 1, row.getByte(pos)) | ||
|
|
||
| case BooleanType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setBoolean(pos + 1, row.getBoolean(pos)) | ||
|
|
||
| case StringType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setString(pos + 1, row.getString(pos)) | ||
|
|
||
| case BinaryType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos)) | ||
|
|
||
| case TimestampType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos)) | ||
|
|
||
| case DateType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos)) | ||
|
|
||
| case t: DecimalType => | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| stmt.setBigDecimal(pos + 1, row.getDecimal(pos)) | ||
|
|
||
| case ArrayType(et, _) => | ||
| // remove type length parameters from end of type name | ||
| val typeName = getJdbcType(et, dialect).databaseTypeDefinition | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Late to the party 😝 but why is the type name converted to lower case?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the same as the original code. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh yeah from https://github.com/apache/spark/pull/14323/files#diff-c3859e97335ead4b131263565c987d877bea0af3adbd6c5bf2d3716768d2e083L244 |
||
| .toLowerCase.split("\\(")(0) | ||
| (stmt: PreparedStatement, row: Row, pos: Int) => | ||
| val array = conn.createArrayOf( | ||
| typeName, | ||
| row.getSeq[AnyRef](pos).toArray) | ||
| stmt.setArray(pos + 1, array) | ||
|
|
||
| case _ => | ||
| (_: PreparedStatement, _: Row, pos: Int) => | ||
| throw new IllegalArgumentException( | ||
| s"Can't translate non-null value for field $pos") | ||
| } | ||
|
|
||
| /** | ||
| * Saves a partition of a DataFrame to the JDBC database. This is done in | ||
| * a single database transaction (unless isolation level is "NONE") | ||
|
|
@@ -215,6 +288,9 @@ object JdbcUtils extends Logging { | |
| conn.setTransactionIsolation(finalIsolationLevel) | ||
| } | ||
| val stmt = insertStatement(conn, table, rddSchema, dialect) | ||
| val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType) | ||
| .map(makeSetter(conn, dialect, _)).toArray | ||
|
|
||
| try { | ||
| var rowCount = 0 | ||
| while (iterator.hasNext) { | ||
|
|
@@ -225,30 +301,7 @@ object JdbcUtils extends Logging { | |
| if (row.isNullAt(i)) { | ||
| stmt.setNull(i + 1, nullTypes(i)) | ||
| } else { | ||
| rddSchema.fields(i).dataType match { | ||
| case IntegerType => stmt.setInt(i + 1, row.getInt(i)) | ||
| case LongType => stmt.setLong(i + 1, row.getLong(i)) | ||
| case DoubleType => stmt.setDouble(i + 1, row.getDouble(i)) | ||
| case FloatType => stmt.setFloat(i + 1, row.getFloat(i)) | ||
| case ShortType => stmt.setInt(i + 1, row.getShort(i)) | ||
| case ByteType => stmt.setInt(i + 1, row.getByte(i)) | ||
| case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) | ||
| case StringType => stmt.setString(i + 1, row.getString(i)) | ||
| case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) | ||
| case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) | ||
| case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) | ||
| case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) | ||
| case ArrayType(et, _) => | ||
| // remove type length parameters from end of type name | ||
| val typeName = getJdbcType(et, dialect).databaseTypeDefinition | ||
| .toLowerCase.split("\\(")(0) | ||
| val array = conn.createArrayOf( | ||
| typeName, | ||
| row.getSeq[AnyRef](i).toArray) | ||
| stmt.setArray(i + 1, array) | ||
| case _ => throw new IllegalArgumentException( | ||
| s"Can't translate non-null value for field $i") | ||
| } | ||
| setters(i).apply(stmt, row, i) | ||
| } | ||
| i = i + 1 | ||
| } | ||
|
|
@@ -333,5 +386,4 @@ object JdbcUtils extends Logging { | |
| getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) | ||
| ) | ||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please rename the read path too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed!