Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,15 @@ import org.apache.spark.util.Utils
* A manager for user-defined data sources. It is used to register and lookup data sources by
* their short names or fully qualified names.
*/
class DataSourceManager(
initDataSourceBuilders: => Option[
Map[String, UserDefinedPythonDataSource]] = None
) extends Logging {
class DataSourceManager extends Logging {
import DataSourceManager._

// Lazy to avoid being invoked during Session initialization.
// Otherwise, it goes infinite loop, session -> Python runner -> SQLConf -> session.
private lazy val staticDataSourceBuilders = initDataSourceBuilders.getOrElse {
initialDataSourceBuilders
}
private lazy val staticDataSourceBuilders = initialStaticDataSourceBuilders

private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
private val runtimeDataSourceBuilders =
new ConcurrentHashMap[String, UserDefinedPythonDataSource]()

/**
* Register a data source builder for the given provider.
Expand All @@ -55,7 +52,7 @@ class DataSourceManager(
// Cannot overwrite static Python Data Sources.
throw QueryCompilationErrors.dataSourceAlreadyExists(name)
}
val previousValue = dataSourceBuilders.put(normalizedName, source)
val previousValue = runtimeDataSourceBuilders.put(normalizedName, source)
if (previousValue != null) {
logWarning(f"The data source $name replaced a previously registered data source.")
}
Expand All @@ -69,7 +66,7 @@ class DataSourceManager(
if (dataSourceExists(name)) {
val normalizedName = normalize(name)
staticDataSourceBuilders.getOrElse(
normalizedName, dataSourceBuilders.get(normalizedName))
normalizedName, runtimeDataSourceBuilders.get(normalizedName))
} else {
throw QueryCompilationErrors.dataSourceDoesNotExist(name)
}
Expand All @@ -81,12 +78,12 @@ class DataSourceManager(
def dataSourceExists(name: String): Boolean = {
val normalizedName = normalize(name)
staticDataSourceBuilders.contains(normalizedName) ||
dataSourceBuilders.containsKey(normalizedName)
runtimeDataSourceBuilders.containsKey(normalizedName)
}

override def clone(): DataSourceManager = {
val manager = new DataSourceManager(Some(staticDataSourceBuilders))
dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
val manager = new DataSourceManager
runtimeDataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
manager
}
}
Expand All @@ -103,7 +100,7 @@ object DataSourceManager extends Logging {

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)

private def initialDataSourceBuilders: Map[String, UserDefinedPythonDataSource] = {
private def initialStaticDataSourceBuilders: Map[String, UserDefinedPythonDataSource] = {
if (Utils.isTesting || shouldLoadPythonDataSources) this.synchronized {
if (dataSourceBuilders.isEmpty) {
val maybeResult = try {
Expand Down