From aa095a986bcf2cbfddabe75612ce6a78c524de2b Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Thu, 9 Nov 2023 16:29:19 -0800
Subject: [PATCH 1/4] change scope

---
 .../sql/tests/test_python_datasource.py       | 29 +++++++++++++++++++
 .../apache/spark/sql/DataFrameReader.scala    |  4 +--
 .../org/apache/spark/sql/SparkSession.scala   |  2 +-
 .../datasources/DataSourceManager.scala       | 15 +++++++---
 .../internal/BaseSessionStateBuilder.scala    | 15 ++++++++++
 .../spark/sql/internal/SessionState.scala     |  5 ++++
 .../spark/sql/internal/SharedState.scala      | 12 --------
 7 files changed, 63 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 46b9fa642fd0..fc18c42d606b 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -49,6 +49,35 @@ def read(self, partition):
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
+    def test_data_source_register(self):
+        class TestReader(DataSourceReader):
+            def read(self, partition):
+                yield (0, 1)
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a INT, b INT"
+
+            def reader(self, schema):
+                return TestReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(a=0, b=1)])
+
+        class MyDataSource(TestDataSource):
+            @classmethod
+            def name(cls):
+                return "TestDataSource"
+            def schema(self):
+                return "c INT, d INT"
+
+        # Should be able to register the data source with the same name.
+        self.spark.dataSource.register(MyDataSource)
+
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(c=0, d=1)])
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 7fadbbfac687..c29ffb329072 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -210,7 +210,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     val isUserDefinedDataSource =
-      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+      sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
 
     Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
       case Success(providerOpt) =>
@@ -243,7 +243,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   }
 
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
     // Add `path` and `paths` options to the extra options if specified.
     val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
     val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5eba9e59c17b..24497add04f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -233,7 +233,7 @@ class SparkSession private(
   /**
    * A collection of methods for registering user-defined data sources.
    */
-  private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+  private[sql] def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration
 
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index a8c9c892b8b0..1cdc3d9cb69e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.StructType
  * A manager for user-defined data sources. It is used to register and lookup data sources by
  * their short names or fully qualified names.
  */
-class DataSourceManager {
+class DataSourceManager extends Logging {
 
   private type DataSourceBuilder = (
     SparkSession,  // Spark session
@@ -49,10 +50,10 @@ class DataSourceManager {
    */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
-    if (dataSourceBuilders.containsKey(normalizedName)) {
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    val previousValue = dataSourceBuilders.put(normalizedName, builder)
+    if (previousValue != null) {
+      logWarning(s"The data source $name replaced a previously registered data source.")
     }
-    dataSourceBuilders.put(normalizedName, builder)
   }
 
   /**
@@ -73,4 +74,10 @@ class DataSourceManager {
   def dataSourceExists(name: String): Boolean = {
     dataSourceBuilders.containsKey(normalize(name))
   }
+
+  override def clone(): DataSourceManager = {
+    val manager = new DataSourceManager
+    dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
+    manager
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 630e1202f6d3..d198e8f5d1f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -120,6 +120,13 @@ abstract class BaseSessionStateBuilder(
       .getOrElse(extensions.registerTableFunctions(TableFunctionRegistry.builtin.clone()))
   }
 
+  /**
+   * Manages the registration of data sources.
+   */
+  protected lazy val dataSourceManager: DataSourceManager = {
+    parentState.map(_.dataSourceManager.clone()).getOrElse(new DataSourceManager)
+  }
+
   /**
    * Experimental methods that can be used to define custom optimization rules and custom planning
    * strategies.
@@ -178,6 +185,12 @@ abstract class BaseSessionStateBuilder(
 
   protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
 
+  /**
+   * A collection of methods used for registering user-defined data sources.
+   */
+  protected def dataSourceRegistration: DataSourceRegistration =
+    new DataSourceRegistration(dataSourceManager)
+
   /**
    * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -376,6 +389,8 @@ abstract class BaseSessionStateBuilder(
       tableFunctionRegistry,
       udfRegistration,
       udtfRegistration,
+      dataSourceManager,
+      dataSourceRegistration,
       () => catalog,
       sqlParser,
       () => analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index adf3e0cb6cad..bc6710e6cbdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{DependencyUtils, Utils}
@@ -49,6 +50,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
  * @param udfRegistration Interface exposed to the user for registering user-defined functions.
  * @param udtfRegistration Interface exposed to the user for registering user-defined
  *                         table functions.
+ * @param dataSourceManager Internal catalog for managing data sources registered by users.
+ * @param dataSourceRegistration Interface exposed to users for registering data sources.
  * @param catalogBuilder a function to create an internal catalog for managing table and database
  *                       states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -73,6 +76,8 @@ private[sql] class SessionState(
     val tableFunctionRegistry: TableFunctionRegistry,
     val udfRegistration: UDFRegistration,
     val udtfRegistration: UDTFRegistration,
+    val dataSourceManager: DataSourceManager,
+    val dataSourceRegistration: DataSourceRegistration,
     catalogBuilder: () => SessionCatalog,
     val sqlParser: ParserInterface,
     analyzerBuilder: () => Analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 8adc32fcf621..164710cdd883 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -30,11 +30,9 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataSourceRegistration
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -107,16 +105,6 @@ private[sql] class SharedState(
   @GuardedBy("activeQueriesLock")
   private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
 
-  /**
-   * A data source manager shared by all sessions.
-   */
-  lazy val dataSourceManager = new DataSourceManager()
-
-  /**
-   * A collection of method used for registering user-defined data sources.
-   */
-  lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
-
   /**
    * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
    * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
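
For context: end to end, patch 1 makes data source registration session-scoped and turns duplicate registration into replace-with-warning. A minimal PySpark sketch of the resulting behavior, assuming an active SparkSession bound to `spark`; the name "my_source" and the classes below are illustrative, not taken from the patch:

    from pyspark.sql.datasource import DataSource, DataSourceReader

    class SimpleReader(DataSourceReader):
        def read(self, partition):
            # One row per partition, matching the "a INT, b INT" schema below.
            yield (0, 1)

    class SimpleSource(DataSource):
        @classmethod
        def name(cls):
            # Short name used by spark.read.format(); defaults to the class name.
            return "my_source"

        def schema(self):
            return "a INT, b INT"

        def reader(self, schema):
            return SimpleReader()

    spark.dataSource.register(SimpleSource)
    spark.read.format("my_source").load().show()  # columns a, b

    class OverridingSource(SimpleSource):
        def schema(self):
            return "c INT, d INT"

    # With this series applied, re-registering an existing name logs a warning and
    # replaces the previous builder instead of raising DATA_SOURCE_ALREADY_EXISTS.
    spark.dataSource.register(OverridingSource)
    spark.read.format("my_source").load().show()  # columns c, d
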
From 944d3762035a0829bc74e8e90ea21a6faf51f61f Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Thu, 16 Nov 2023 20:22:39 -0800
Subject: [PATCH 2/4] style

---
 python/pyspark/sql/tests/test_python_datasource.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index fc18c42d606b..bab062c48215 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -69,6 +69,7 @@ class MyDataSource(TestDataSource):
             @classmethod
             def name(cls):
                 return "TestDataSource"
+
             def schema(self):
                 return "c INT, d INT"
 

From f6ff4cc2a4ffdaecb32f872c90ad13352b42ed06 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Fri, 17 Nov 2023 13:42:10 -0800
Subject: [PATCH 3/4] fix tests

---
 .../python/PythonDataSourceSuite.scala        | 38 ++++++++++++++-----
 1 file changed, 29 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index bd0b08cbec8b..ba19bc1bb2a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -143,16 +144,35 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds1(spark, dataSourceName, Seq.empty, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
-    // Check error when registering a data source with the same name.
-    val err = intercept[AnalysisException] {
-      spark.dataSource.registerPython(dataSourceName, dataSource)
-    }
-    checkError(
-      exception = err,
-      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
-      parameters = Map("provider" -> dataSourceName))
+    // Should be able to override an already registered data source.
+    val newScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def read(self, partition):
+         |        yield (0, )
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript)
+    spark.dataSource.registerPython(dataSourceName, newDataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+
+    val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds2(spark, dataSourceName, Seq.empty, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0)))
   }
 
   test("load data source") {

From 5c8501f42b9e2ea876bb34592f60249689ca3d78 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Tue, 21 Nov 2023 21:29:39 -0800
Subject: [PATCH 4/4] fix tests

---
 .../spark/sql/execution/python/PythonDataSourceSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index ba19bc1bb2a9..33b34b39ab2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -147,7 +147,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
     val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
     checkAnswer(
-      ds1(spark, dataSourceName, Seq.empty, None, CaseInsensitiveMap(Map.empty)),
+      ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
       Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
     // Should be able to override an already registered data source.
@@ -171,7 +171,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
     checkAnswer(
-      ds2(spark, dataSourceName, Seq.empty, None, CaseInsensitiveMap(Map.empty)),
+      ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
       Seq(Row(0)))
   }
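
For context: because the manager now lives in SessionState and BaseSessionStateBuilder only clones it when a parent state exists, registrations no longer span the whole application. A hedged sketch of the expected scoping, reusing the illustrative "my_source" registration from the note after patch 1:

    # Assumes `spark` still has "my_source" registered as in the earlier sketch.
    other = spark.newSession()

    # newSession() builds a fresh SessionState with no parent to clone from, so
    # the new session's DataSourceManager starts empty and "my_source" is not
    # visible here.
    try:
        other.read.format("my_source").load()
    except Exception as err:
        print(err)  # expected: the data source lookup fails in this session

    # A session built from a parent state (the internal cloneSession() path)
    # would instead inherit existing registrations via DataSourceManager.clone().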