DataSourceManager.scala
@@ -21,8 +21,6 @@ import java.io.File
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.jdk.CollectionConverters._
-
 import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -34,23 +32,29 @@ import org.apache.spark.util.Utils
  * A manager for user-defined data sources. It is used to register and lookup data sources by
  * their short names or fully qualified names.
  */
-class DataSourceManager extends Logging {
+class DataSourceManager(
+    initDataSourceBuilders: => Option[
+      Map[String, UserDefinedPythonDataSource]] = None
+) extends Logging {
   import DataSourceManager._
   // Lazy to avoid being invoked during Session initialization.
   // Otherwise, it goes infinite loop, session -> Python runner -> SQLConf -> session.
-  private lazy val dataSourceBuilders = {
-    val builders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
-    builders.putAll(DataSourceManager.initialDataSourceBuilders.asJava)
-    builders
+  private lazy val staticDataSourceBuilders = initDataSourceBuilders.getOrElse {
+    initialDataSourceBuilders
[Review comment thread]

Contributor:
    I think the DataSourceManager is session-level and should not be the one to
    initialize static data sources. When we initialize the DataSourceManager for
    each Spark session, we can pass in the static ones.

    So it might make more sense to have an API in SparkContext for static data
    sources?

Member (Author):
    Yeah, I agree, but the problem is that
    UserDefinedPythonDataSourceLookupRunner.runInPython requires SQLConf.get,
    which in turn requires SparkSession initialization.

    So the static data sources cannot be initialized before a session is created.
    For now, the static initialization happens on the first call into
    DataSourceManager in any session.

Contributor:
    Ah, because of these two configs:

        val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
        val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

    Instead of accessing the SQLConf here, I think we should pass them as
    parameters to runInPython to avoid this initialization issue. Maybe we can
    add a TODO for a follow-up PR?
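To make that suggestion concrete, here is a minimal sketch of the proposed direction; the class name and signature below are illustrative assumptions, not Spark's actual API:

    // Hypothetical sketch only: Spark's real UserDefinedPythonDataSourceLookupRunner
    // differs. The idea is to take the two config values as constructor parameters
    // so runInPython never calls SQLConf.get and needs no active SparkSession.
    class LookupRunnerSketch(
        simplifiedTraceback: Boolean, // was: SQLConf.get.pysparkSimplifiedTraceback
        workerMemoryMb: Long) {       // was: SQLConf.get.pythonPlannerExecMemory

      def runInPython(): Unit = {
        // Launch the Python planner worker with the values captured above;
        // left as a stub since the worker machinery is out of scope here.
      }
    }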

   }
 
-  private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
+  private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
 
   /**
    * Register a data source builder for the given provider.
    * Note that the provider name is case-insensitive.
    */
   def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
     val normalizedName = normalize(name)
+    if (staticDataSourceBuilders.contains(normalizedName)) {
+      // Cannot overwrite static Python Data Sources.
+      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    }
     val previousValue = dataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(f"The data source $name replaced a previously registered data source.")
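Aside: the deferral described by the "Lazy to avoid being invoked during Session initialization" comment combines a by-name (=>) constructor parameter with a lazy val. A self-contained sketch with toy names, not Spark code:

    // `init` is by-name, so the argument expression is not evaluated at
    // construction; `lazy val` then evaluates it at most once, on first access.
    class Holder(init: => Option[Map[String, String]] = None) {
      private lazy val static: Map[String, String] = init.getOrElse {
        println("computing defaults") // stands in for the expensive Python lookup
        Map("a" -> "1")
      }
      def get(name: String): Option[String] = static.get(name)
    }

    val h = new Holder() // prints nothing: nothing evaluated yet
    h.get("a")           // prints "computing defaults" exactly once
    h.get("a")           // cached: prints nothing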
@@ -63,7 +67,9 @@ class DataSourceManager extends Logging {
    */
   def lookupDataSource(name: String): UserDefinedPythonDataSource = {
     if (dataSourceExists(name)) {
-      dataSourceBuilders.get(normalize(name))
+      val normalizedName = normalize(name)
+      staticDataSourceBuilders.getOrElse(
+        normalizedName, dataSourceBuilders.get(normalizedName))
     } else {
       throw QueryCompilationErrors.dataSourceDoesNotExist(name)
     }
@@ -73,11 +79,13 @@
    * Checks if a data source with the specified name exists (case-insensitive).
    */
   def dataSourceExists(name: String): Boolean = {
-    dataSourceBuilders.containsKey(normalize(name))
+    val normalizedName = normalize(name)
+    staticDataSourceBuilders.contains(normalizedName) ||
+      dataSourceBuilders.containsKey(normalizedName)
   }
 
   override def clone(): DataSourceManager = {
-    val manager = new DataSourceManager
+    val manager = new DataSourceManager(Some(staticDataSourceBuilders))
     dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
     manager
   }
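The clone above shares the static builders by reference and replays only the runtime registrations. A self-contained sketch of that split, with generic names rather than the Spark types:

    import java.util.concurrent.ConcurrentHashMap

    class Registry[V](static: Map[String, V] = Map.empty[String, V]) {
      private val runtime = new ConcurrentHashMap[String, V]()

      def register(name: String, v: V): Unit = {
        // Mirrors registerDataSource: static entries cannot be overwritten.
        require(!static.contains(name), s"Cannot overwrite static entry: $name")
        runtime.put(name, v)
      }

      def lookup(name: String): Option[V] =
        static.get(name).orElse(Option(runtime.get(name))) // static side wins

      // Shares the static map and re-registers the runtime entries, so cloning
      // never triggers the expensive static initialization.
      override def clone(): Registry[V] = {
        val copy = new Registry[V](static)
        runtime.forEach((k, v) => copy.register(k, v))
        copy
      }
    }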
@@ -93,6 +101,8 @@ object DataSourceManager extends Logging {
     PythonUtils.sparkPythonPaths.forall(new File(_).exists())
   }
 
+  private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
+
   private def initialDataSourceBuilders: Map[String, UserDefinedPythonDataSource] = {
     if (Utils.isTesting || shouldLoadPythonDataSources) this.synchronized {
       if (dataSourceBuilders.isEmpty) {
@@ -109,7 +119,7 @@
 
         dataSourceBuilders = maybeResult.map { result =>
           result.names.zip(result.dataSources).map { case (name, dataSource) =>
-            name ->
+            normalize(name) ->
               UserDefinedPythonDataSource(PythonUtils.createPythonFunction(dataSource))
           }.toMap
         }
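A note on normalize using Locale.ROOT rather than the default locale: locale-sensitive lowercasing can mangle identifiers, the classic case being the Turkish dotless ı. A quick demonstration:

    import java.util.Locale

    val turkish = new Locale("tr", "TR")
    println("FILE".toLowerCase(turkish))     // "fıle": 'I' lowercases to dotless ı
    println("FILE".toLowerCase(Locale.ROOT)) // "file": stable regardless of locale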
DataSourceManagerSuite.scala (new file)
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkFunSuite
+
+class DataSourceManagerSuite extends SparkFunSuite {
+  test("SPARK-46670: DataSourceManager should be self clone-able without lookup") {
+    val testAppender = new LogAppender("Cloneable DataSourceManager without lookup")
+    withLogAppender(testAppender) {
+      new DataSourceManager().clone()
+    }
+    assert(!testAppender.loggingEvents
+      .exists(msg =>
+        msg.getMessage.getFormattedMessage.contains("Skipping the lookup of Python Data Sources")))
+  }
+}