Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -322,19 +322,19 @@ private[sql] class JDBCRDD(
}
}

// A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
// into a field for `MutableRow`. The last argument `Int` means the index for the
// value to be set in the row and also used for the value to retrieve from `ResultSet`.
private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
// A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
// for `MutableRow`. The last argument `Int` means the index for the value to be set in
// the row and also used for the value in `ResultSet`.
private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit

/**
* Creates `JDBCValueSetter`s according to [[StructType]], which can set
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[MutableRow]] correctly.
*/
def makeSetters(schema: StructType): Array[JDBCValueSetter] =
schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
def makeGetters(schema: StructType): Array[JDBCValueGetter] =
schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
Expand Down Expand Up @@ -489,15 +489,15 @@ private[sql] class JDBCRDD(
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()

val setters: Array[JDBCValueSetter] = makeSetters(schema)
val getters: Array[JDBCValueGetter] = makeGetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
while (i < setters.length) {
setters(i).apply(rs, mutableRow, i)
while (i < getters.length) {
getters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,79 @@ object JdbcUtils extends Logging {
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}

// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.
private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please rename the read path too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!


private def makeSetter(
conn: Connection,
dialect: JdbcDialect,
dataType: DataType): JDBCValueSetter = dataType match {
case IntegerType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setInt(pos + 1, row.getInt(pos))

case LongType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setLong(pos + 1, row.getLong(pos))

case DoubleType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setDouble(pos + 1, row.getDouble(pos))

case FloatType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setFloat(pos + 1, row.getFloat(pos))

case ShortType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setInt(pos + 1, row.getShort(pos))

case ByteType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setInt(pos + 1, row.getByte(pos))

case BooleanType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setBoolean(pos + 1, row.getBoolean(pos))

case StringType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setString(pos + 1, row.getString(pos))

case BinaryType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

case TimestampType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

case DateType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

case t: DecimalType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

case ArrayType(et, _) =>
// remove type length parameters from end of type name
val typeName = getJdbcType(et, dialect).databaseTypeDefinition
Copy link

@jneira-stratio jneira-stratio Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Late to the party 😝 but why the type name is converted to lower case?
It makes code trying to get a JDBCType enum value using the string fails with: java.lang.IllegalArgumentException: No enum constant java.sql.JDBCType.varchar 😞
//cc @HyukjinKwon

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same as the original code

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.toLowerCase.split("\\(")(0)
(stmt: PreparedStatement, row: Row, pos: Int) =>
val array = conn.createArrayOf(
typeName,
row.getSeq[AnyRef](pos).toArray)
stmt.setArray(pos + 1, array)

case _ =>
(_: PreparedStatement, _: Row, pos: Int) =>
throw new IllegalArgumentException(
s"Can't translate non-null value for field $pos")
}

/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction (unless isolation level is "NONE")
Expand Down Expand Up @@ -215,6 +288,9 @@ object JdbcUtils extends Logging {
conn.setTransactionIsolation(finalIsolationLevel)
}
val stmt = insertStatement(conn, table, rddSchema, dialect)
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
.map(makeSetter(conn, dialect, _)).toArray

try {
var rowCount = 0
while (iterator.hasNext) {
Expand All @@ -225,30 +301,7 @@ object JdbcUtils extends Logging {
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
rddSchema.fields(i).dataType match {
case IntegerType => stmt.setInt(i + 1, row.getInt(i))
case LongType => stmt.setLong(i + 1, row.getLong(i))
case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
case ShortType => stmt.setInt(i + 1, row.getShort(i))
case ByteType => stmt.setInt(i + 1, row.getByte(i))
case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
case StringType => stmt.setString(i + 1, row.getString(i))
case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case ArrayType(et, _) =>
// remove type length parameters from end of type name
val typeName = getJdbcType(et, dialect).databaseTypeDefinition
.toLowerCase.split("\\(")(0)
val array = conn.createArrayOf(
typeName,
row.getSeq[AnyRef](i).toArray)
stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
setters(i).apply(stmt, row, i)
}
i = i + 1
}
Expand Down Expand Up @@ -333,5 +386,4 @@ object JdbcUtils extends Logging {
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)
}

}