From 2278e3fc02e68cda4efff9ca1d704ea3f774b838 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 16 Jan 2024 09:38:40 +0900 Subject: [PATCH] Do not pass static Python Data Sources around --- .../datasources/DataSourceManager.scala | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index 28c93357d8b47..8ee2325ca1f94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -32,18 +32,15 @@ import org.apache.spark.util.Utils * A manager for user-defined data sources. It is used to register and lookup data sources by * their short names or fully qualified names. */ -class DataSourceManager( - initDataSourceBuilders: => Option[ - Map[String, UserDefinedPythonDataSource]] = None - ) extends Logging { +class DataSourceManager extends Logging { import DataSourceManager._ + // Lazy to avoid being invoked during Session initialization. // Otherwise, it goes infinite loop, session -> Python runner -> SQLConf -> session. - private lazy val staticDataSourceBuilders = initDataSourceBuilders.getOrElse { - initialDataSourceBuilders - } + private lazy val staticDataSourceBuilders = initialStaticDataSourceBuilders - private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]() + private val runtimeDataSourceBuilders = + new ConcurrentHashMap[String, UserDefinedPythonDataSource]() /** * Register a data source builder for the given provider. @@ -55,7 +52,7 @@ class DataSourceManager( // Cannot overwrite static Python Data Sources. throw QueryCompilationErrors.dataSourceAlreadyExists(name) } - val previousValue = dataSourceBuilders.put(normalizedName, source) + val previousValue = runtimeDataSourceBuilders.put(normalizedName, source) if (previousValue != null) { logWarning(f"The data source $name replaced a previously registered data source.") } @@ -69,7 +66,7 @@ class DataSourceManager( if (dataSourceExists(name)) { val normalizedName = normalize(name) staticDataSourceBuilders.getOrElse( - normalizedName, dataSourceBuilders.get(normalizedName)) + normalizedName, runtimeDataSourceBuilders.get(normalizedName)) } else { throw QueryCompilationErrors.dataSourceDoesNotExist(name) } @@ -81,12 +78,12 @@ class DataSourceManager( def dataSourceExists(name: String): Boolean = { val normalizedName = normalize(name) staticDataSourceBuilders.contains(normalizedName) || - dataSourceBuilders.containsKey(normalizedName) + runtimeDataSourceBuilders.containsKey(normalizedName) } override def clone(): DataSourceManager = { - val manager = new DataSourceManager(Some(staticDataSourceBuilders)) - dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v)) + val manager = new DataSourceManager + runtimeDataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v)) manager } } @@ -103,7 +100,7 @@ object DataSourceManager extends Logging { private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) - private def initialDataSourceBuilders: Map[String, UserDefinedPythonDataSource] = { + private def initialStaticDataSourceBuilders: Map[String, UserDefinedPythonDataSource] = { if (Utils.isTesting || shouldLoadPythonDataSources) this.synchronized { if (dataSourceBuilders.isEmpty) { val maybeResult = try {