diff --git a/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLDirectJDBC.scala b/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLDirectJDBC.scala
index 2f71d382..9faf5d10 100644
--- a/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLDirectJDBC.scala
+++ b/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLDirectJDBC.scala
@@ -1,11 +1,11 @@
 package streaming.core.datasource.impl
 
 import java.util.Properties
-
 import com.alibaba.druid.sql.SQLUtils
 import com.alibaba.druid.sql.repository.SchemaRepository
 import com.alibaba.druid.sql.visitor.SchemaStatVisitor
-import com.alibaba.druid.util.JdbcConstants
+import com.alibaba.druid.util.{JdbcConstants, JdbcUtils}
+import org.apache.spark.{MLSQLSparkConst, SparkCoreVersion}
 import org.apache.spark.sql.catalyst.plans.logical.MLSQLDFParser
 import org.apache.spark.sql.execution.WowTableIdentifier
 import org.apache.spark.sql.mlsql.session.{MLSQLException, MLSQLSparkSession}
@@ -53,17 +53,36 @@ class MLSQLDirectJDBC extends MLSQLDirectSource with MLSQLDirectSink with MLSQLS
 
   override def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame = {
     val format = config.config.getOrElse("implClass", fullFormat)
+    var driver: Option[String] = config.config.get("driver")
     var url = config.config.get("url")
     if (config.path.contains(dbSplitter)) {
-      val Array(_dbname, _dbtable) = config.path.split(toSplit, 2)
+      val Array(_dbname, _) = config.path.split(toSplit, 2)
       ConnectMeta.presentThenCall(DBMappingKey(format, _dbname), options => {
         reader.options(options)
         url = options.get("url")
+        driver = options.get("driver")
       })
     }
-
     //load configs should overwrite connect configs
     reader.options(config.config)
+    assert(url.isDefined, "url must not be null!")
+    assert(driver.isDefined, "driver must not be null!")
+    if (JdbcUtils.isMySqlDriver(driver.get)) {
+      /**
+        * Fetch size is a hint for the JDBC PreparedStatement. To avoid loading too much
+        * data into the JVM at once and causing OOM, we default it to Integer.MIN_VALUE.
+        */
+      MLSQLSparkConst.majorVersion(SparkCoreVersion.exactVersion) match {
+        case 1 | 2 =>
+          reader.options(Map("fetchsize" -> "1000"))
+        case _ =>
+          reader.options(Map("fetchsize" -> s"${Integer.MIN_VALUE}"))
+      }
+
+      url = url.map(x => if (x.contains("useCursorFetch")) x else s"$x&useCursorFetch=true")
+        .map(x => if (x.contains("autoReconnect")) x else s"$x&autoReconnect=true")
+        .map(x => if (x.contains("failOverReadOnly")) x else s"$x&failOverReadOnly=false")
+    }
 
     val dbtable = "(" + config.config("directQuery") + ") temp"
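
For context on the MySQL-specific branch above: the url rewriting appends each connection
parameter only when the url does not already mention it, and it always joins with '&'. That
assumes the JDBC url already carries at least one query parameter (e.g. "?characterEncoding=utf8");
a bare url such as "jdbc:mysql://host/db" would end up malformed ("...db&useCursorFetch=true").
A minimal standalone sketch of that behavior (the object and method names are illustrative,
not part of the patch):

    object UrlRewriteSketch {
      // Mirrors the patch: append each MySQL option only if the url does not mention it yet.
      def withMySqlStreamingParams(url: String): String =
        Some(url)
          .map(x => if (x.contains("useCursorFetch")) x else s"$x&useCursorFetch=true")
          .map(x => if (x.contains("autoReconnect")) x else s"$x&autoReconnect=true")
          .map(x => if (x.contains("failOverReadOnly")) x else s"$x&failOverReadOnly=false")
          .get

      def main(args: Array[String]): Unit = {
        println(withMySqlStreamingParams("jdbc:mysql://127.0.0.1:3306/wow?characterEncoding=utf8"))
        // jdbc:mysql://127.0.0.1:3306/wow?characterEncoding=utf8&useCursorFetch=true&autoReconnect=true&failOverReadOnly=false
      }
    }

The same fetchsize/url logic is duplicated in MLSQLJDBC.scala below; it could later be pulled
into a shared helper.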
diff --git a/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLJDBC.scala b/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLJDBC.scala
index 9b51f810..3ca95228 100644
--- a/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLJDBC.scala
+++ b/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLJDBC.scala
@@ -19,11 +19,12 @@ package streaming.core.datasource.impl
 
 import java.util.Properties
-
 import _root_.streaming.core.datasource.{SourceTypeRegistry, _}
 import _root_.streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
 import _root_.streaming.dsl.{ConnectMeta, DBMappingKey, ScriptSQLExec}
 import _root_.streaming.log.WowLog
+import com.alibaba.druid.util.JdbcUtils
+import org.apache.spark.{MLSQLSparkConst, SparkCoreVersion}
 import org.apache.spark.ml.param.{BooleanParam, LongParam, Param}
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
@@ -45,31 +46,50 @@ class MLSQLJDBC(override val uid: String) extends MLSQLSource with MLSQLSink wit
 
   def toSplit = "\\."
 
   override def load(reader: DataFrameReader, config: DataSourceConfig): DataFrame = {
-    var dbtable = config.path
+    var dbTable = config.path
     // if contains splitter, then we will try to find dbname in dbMapping.
     // otherwize we will do nothing since elasticsearch use something like index/type
     // it will do no harm.
     val format = config.config.getOrElse("implClass", fullFormat)
+    var driver: Option[String] = config.config.get("driver")
     var url = config.config.get("url")
     if (config.path.contains(dbSplitter)) {
-      val Array(_dbname, _dbtable) = config.path.split(toSplit, 2)
+      val Array(_dbname, _dbTable) = config.path.split(toSplit, 2)
       ConnectMeta.presentThenCall(DBMappingKey(format, _dbname), options => {
-        dbtable = _dbtable
+        dbTable = _dbTable
         reader.options(options)
         url = options.get("url")
+        driver = options.get("driver")
       })
     }
     //load configs should overwrite connect configs
     reader.options(config.config)
+    assert(url.isDefined, "url must not be null!")
+    assert(driver.isDefined, "driver must not be null!")
+    if (JdbcUtils.isMySqlDriver(driver.get)) {
+      /**
+        * Fetch size is a hint for the JDBC PreparedStatement. To avoid loading too much
+        * data into the JVM at once and causing OOM, we default it to Integer.MIN_VALUE.
+        */
+      MLSQLSparkConst.majorVersion(SparkCoreVersion.exactVersion) match {
+        case 1 | 2 =>
+          reader.options(Map("fetchsize" -> "1000"))
+        case _ =>
+          reader.options(Map("fetchsize" -> s"${Integer.MIN_VALUE}"))
+      }
+
+      url = url.map(x => if (x.contains("useCursorFetch")) x else s"$x&useCursorFetch=true")
+        .map(x => if (x.contains("autoReconnect")) x else s"$x&autoReconnect=true")
+        .map(x => if (x.contains("failOverReadOnly")) x else s"$x&failOverReadOnly=false")
+    }
 
     val table = if (config.config.contains("prePtnArray")){
       val prePtn = config.config.get("prePtnArray").get
        .split(config.config.getOrElse("prePtnDelimiter" ,","))
-      reader.jdbc(url.get, dbtable, prePtn, new Properties())
+      reader.jdbc(url.get, dbTable, prePtn, new Properties())
     }else{
-      reader.option("dbtable", dbtable)
-
+      reader.option("dbtable", dbTable)
       reader.format(format).load()
     }
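
The prePtnArray branch at the end hands each array entry to Spark's DataFrameReader.jdbc
overload that takes an Array of predicate strings, so every predicate becomes one partition,
i.e. one parallel JDBC query. A minimal usage sketch, assuming a hypothetical local MySQL
database "wow" with a table "orders" (names and predicates are illustrative only):

    import java.util.Properties
    import org.apache.spark.sql.SparkSession

    object JdbcPredicateSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

        val props = new Properties()
        props.setProperty("driver", "com.mysql.jdbc.Driver")
        // Same fetchsize default the patch applies on recent Spark versions.
        props.setProperty("fetchsize", String.valueOf(Integer.MIN_VALUE))

        // Each predicate becomes one Spark partition, i.e. one parallel JDBC query.
        val predicates = Array("id < 1000000", "id >= 1000000")
        val df = spark.read.jdbc(
          "jdbc:mysql://127.0.0.1:3306/wow?useCursorFetch=true&autoReconnect=true&failOverReadOnly=false",
          "orders", predicates, props)

        println(df.rdd.getNumPartitions) // 2
        spark.stop()
      }
    }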