From f2d8b7e0f3bce91e8bdef7bacaae4ff81113ebbd Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Tue, 14 Nov 2023 14:57:36 -0800
Subject: [PATCH 1/4] update

---
 .../src/main/resources/error/error-classes.json  |  6 ++++++
 docs/sql-error-conditions.md                     |  6 ++++++
 python/pyspark/sql/datasource.py                 | 17 +++--------------
 python/pyspark/sql/worker/create_data_source.py  | 14 ++------------
 .../sql/errors/QueryCompilationErrors.scala      |  8 ++++++++
 .../org/apache/spark/sql/DataFrameReader.scala   | 10 +++++++++-
 .../datasources/DataSourceManager.scala          |  1 -
 .../python/UserDefinedPythonDataSource.scala     | 11 ++---------
 .../python/PythonDataSourceSuite.scala           | 15 +++++++++------
 9 files changed, 45 insertions(+), 43 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index afcd841a2ce0..2ff8f28aeee4 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2468,6 +2468,12 @@
     ],
     "sqlState" : "42803"
   },
+  "MULTIPLE_PATHS_UNSUPPORTED" : {
+    "message" : [
+      "Data source '<provider>' does not support load() with multiple paths: '<paths>'. Please load each path individually."
+    ],
+    "sqlState" : "42K02"
+  },
   "MULTIPLE_TIME_TRAVEL_SPEC" : {
     "message" : [
       "Cannot specify time travel in both the time travel clause and options."

diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index cba6a24b8699..1e088dde79f6 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1407,6 +1407,12 @@ For more details see [MISSING_ATTRIBUTES](sql-error-conditions-missing-attribute
 
 The query does not include a GROUP BY clause. Add GROUP BY or turn it into the window functions using OVER clauses.
 
+### MULTIPLE_PATHS_UNSUPPORTED
+
+[SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Data source '`<provider>`' does not support load() with multiple paths: '`<paths>`'. Please load each path individually.
+
 ### MULTIPLE_TIME_TRAVEL_SPEC
 
 [SQLSTATE: 42K0E](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index c30a2c8689d6..b380e8b534eb 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
+from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, TYPE_CHECKING
 
 from pyspark import since
 from pyspark.sql import Row
@@ -45,21 +45,12 @@ class DataSource(ABC):
     """
 
     @final
-    def __init__(
-        self,
-        paths: List[str],
-        userSpecifiedSchema: Optional[StructType],
-        options: Dict[str, "OptionalPrimitiveType"],
-    ) -> None:
+    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
         """
-        Initializes the data source with user-provided information.
+        Initializes the data source with user-provided options.
 
         Parameters
         ----------
-        paths : list
-            A list of paths to the data source.
-        userSpecifiedSchema : StructType, optional
-            The user-specified schema of the data source.
         options : dict
             A dictionary representing the options for this data source.
 
         Notes
         -----
         This method should not be overridden.
         """
-        self.paths = paths
-        self.userSpecifiedSchema = userSpecifiedSchema
         self.options = options
 
     @classmethod

diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 6a9ef79b7c18..77f9d378fce2 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -17,7 +17,7 @@
 import inspect
 import os
 import sys
-from typing import IO, List
+from typing import IO
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
@@ -107,12 +107,6 @@ def main(infile: IO, outfile: IO) -> None:
             },
         )
 
-        # Receive the paths.
-        num_paths = read_int(infile)
-        paths: List[str] = []
-        for _ in range(num_paths):
-            paths.append(utf8_deserializer.loads(infile))
-
         # Receive the user-specified schema
        user_specified_schema = None
         if read_bool(infile):
@@ -136,11 +130,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Instantiate a data source.
         try:
-            data_source = data_source_cls(
-                paths=paths,
-                userSpecifiedSchema=user_specified_schema,  # type: ignore
-                options=options,
-            )
+            data_source = data_source_cls(options=options)
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 603de520b18b..394f01aa36f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3129,6 +3129,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
       "config" -> SQLConf.LEGACY_PATH_OPTION_BEHAVIOR.key))
   }
 
+  def multiplePathsUnsupportedError(provider: String, paths: Seq[String]): Throwable = {
+    new AnalysisException(
+      errorClass = "MULTIPLE_PATHS_UNSUPPORTED",
+      messageParameters = Map(
+        "provider" -> provider,
+        "paths" -> paths.mkString("[", ", ", "]")))
+  }
+
   def pathOptionNotSetCorrectlyWhenWritingError(): Throwable = {
     new AnalysisException(
       errorClass = "_LEGACY_ERROR_TEMP_1307",

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ef447e8a8010..c36dc1436f58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -246,7 +246,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
     // Unless the legacy path option behavior is enabled, the extraOptions here
     // should not include "path" or "paths" as keys.
-    val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
+    // Add path to the options field. Note currently it only supports a single path.
+    val optionsWithPath = if (paths.isEmpty) {
+      extraOptions
+    } else if (paths.length == 1) {
+      extraOptions + ("path" -> paths.head)
+    } else {
+      throw QueryCompilationErrors.multiplePathsUnsupportedError(source, paths)
+    }
+    val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
     Dataset.ofRows(sparkSession, plan)
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 72a9e6497aca..a8c9c892b8b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -35,7 +35,6 @@ class DataSourceManager {
   private type DataSourceBuilder = (
     SparkSession,  // Spark session
     String,  // provider name
-    Seq[String],  // paths
    Option[StructType],  // user specified schema
     CaseInsensitiveMap[String]  // options
   ) => LogicalPlan

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 0e7eb056f434..7044ef65c638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
   def builder(
       sparkSession: SparkSession,
       provider: String,
-      paths: Seq[String],
       userSpecifiedSchema: Option[StructType],
       options: CaseInsensitiveMap[String]): LogicalPlan = {
 
     val runner = new UserDefinedPythonDataSourceRunner(
-      dataSourceCls, provider, paths, userSpecifiedSchema, options)
+      dataSourceCls, provider, userSpecifiedSchema, options)
 
     val result = runner.runInPython()
     val pickledDataSourceInstance = result.dataSource
@@ -68,10 +67,9 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
   def apply(
       sparkSession: SparkSession,
       provider: String,
-      paths: Seq[String] = Seq.empty,
       userSpecifiedSchema: Option[StructType] = None,
       options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
-    val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
+    val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
     Dataset.ofRows(sparkSession, plan)
   }
 }
@@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult(
 class UserDefinedPythonDataSourceRunner(
     dataSourceCls: PythonFunction,
     provider: String,
-    paths: Seq[String],
     userSpecifiedSchema: Option[StructType],
     options: CaseInsensitiveMap[String])
   extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
@@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner(
     // Send the provider name
     PythonWorkerUtils.writeUTF(provider, dataOut)
 
-    // Send the paths
-    dataOut.writeInt(paths.length)
-    paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))
-
     // Send the user-specified schema, if provided
     dataOut.writeBoolean(userSpecifiedSchema.isDefined)
     userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 22a1e5250cd9..c45b995a3de4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -161,12 +161,12 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
         |class SimpleDataSourceReader(DataSourceReader):
-        |    def __init__(self, paths, options):
-        |        self.paths = paths
+        |    def __init__(self, options):
        |        self.options = options
         |
         |    def partitions(self):
-        |        return iter(self.paths)
+        |        paths = self.options.get("path", [])
+        |        return paths
         |
         |    def read(self, path):
         |        yield (path, 1)
@@ -180,14 +180,17 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
         |        return "id STRING, value INT"
         |
         |    def reader(self, schema):
-        |        return SimpleDataSourceReader(self.paths, self.options)
+        |        return SimpleDataSourceReader(self.options)
         |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython("test", dataSource)
-    checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
     checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
-    checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))
+    checkError(
+      exception = intercept[AnalysisException](spark.read.format("test").load("1", "2")),
+      errorClass = "MULTIPLE_PATHS_UNSUPPORTED",
+      parameters = Map("provider" -> "test", "paths" -> "[1, 2]")
+    )
   }
 
   test("reader not implemented") {

From 5c02050b9a7908aca87cce890fe20dfb364e7d2e Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Tue, 14 Nov 2023 15:10:28 -0800
Subject: [PATCH 2/4] fix tests

---
 python/pyspark/sql/tests/test_python_datasource.py | 36 ++++++++++++------------------------
 1 file changed, 12 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index fe6a84175274..46b9fa642fd0 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -30,7 +30,7 @@ class MyDataSource(DataSource):
             ...
 
         options = dict(a=1, b=2)
-        ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
+        ds = MyDataSource(options=options)
         self.assertEqual(ds.options, options)
         self.assertEqual(ds.name(), "MyDataSource")
         with self.assertRaises(NotImplementedError):
@@ -53,8 +53,7 @@ def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
 
-            def __init__(self, paths, options):
-                self.paths = paths
+            def __init__(self, options):
                 self.options = options
 
             def partitions(self):
@@ -76,7 +75,7 @@ def schema(self):
                 return "x INT, y STRING"
 
             def reader(self, schema) -> "DataSourceReader":
-                return InMemDataSourceReader(self.paths, self.options)
+                return InMemDataSourceReader(self.options)
 
         self.spark.dataSource.register(InMemoryDataSource)
         df = self.spark.read.format("memory").load()
@@ -91,14 +90,13 @@ def test_custom_json_data_source(self):
         import json
 
         class JsonDataSourceReader(DataSourceReader):
-            def __init__(self, paths, options):
-                self.paths = paths
+            def __init__(self, options):
                 self.options = options
 
-            def partitions(self):
-                return iter(self.paths)
-
-            def read(self, path):
+            def read(self, partition):
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
                 with open(path, "r") as file:
                     for line in file.readlines():
                         if line.strip():
@@ -114,28 +112,18 @@ def schema(self):
                 return "name STRING, age INT"
 
             def reader(self, schema) -> "DataSourceReader":
-                return JsonDataSourceReader(self.paths, self.options)
+                return JsonDataSourceReader(self.options)
 
         self.spark.dataSource.register(JsonDataSource)
         path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
         path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
-        df1 = self.spark.read.format("my-json").load(path1)
-        self.assertEqual(df1.rdd.getNumPartitions(), 1)
         assertDataFrameEqual(
-            df1,
+            self.spark.read.format("my-json").load(path1),
             [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
         )
-
-        df2 = self.spark.read.format("my-json").load([path1, path2])
-        self.assertEqual(df2.rdd.getNumPartitions(), 2)
         assertDataFrameEqual(
-            df2,
-            [
-                Row(name="Michael", age=None),
-                Row(name="Andy", age=30),
-                Row(name="Justin", age=19),
-                Row(name="Jonathan", age=None),
-            ],
+            self.spark.read.format("my-json").load(path2),
+            [Row(name="Jonathan", age=None)],
         )

From 8fa55602c750d9fd31cc96f83c38a46b07421ca3 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Wed, 15 Nov 2023 15:59:44 -0800
Subject: [PATCH 3/4] update comments

---
 python/pyspark/sql/worker/create_data_source.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 77f9d378fce2..1ba4dc9e8a3c 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None:
     The JVM sends the following information to this process:
     - a `DataSource` class representing the data source to be created.
     - a provider name in string.
-    - a list of paths in string.
     - an optional user-specified schema in json string.
     - a dictionary of options in string.

From 60f230aa2640b7f12ea2d75ba557e917eefd0492 Mon Sep 17 00:00:00 2001
From: allisonwang-db
Date: Thu, 16 Nov 2023 23:11:38 -0800
Subject: [PATCH 4/4] address comments

---
 .../src/main/resources/error/error-classes.json   |  6 ------
 docs/sql-error-conditions.md                      |  6 ------
 .../spark/sql/errors/QueryCompilationErrors.scala |  8 --------
 .../org/apache/spark/sql/DataFrameReader.scala    | 12 ++----------
 .../datasources/v2/DataSourceV2Utils.scala        |  2 +-
 .../execution/python/PythonDataSourceSuite.scala  | 15 +++++++++------
 6 files changed, 12 insertions(+), 37 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 2ff8f28aeee4..afcd841a2ce0 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2468,12 +2468,6 @@
     ],
     "sqlState" : "42803"
   },
-  "MULTIPLE_PATHS_UNSUPPORTED" : {
-    "message" : [
-      "Data source '<provider>' does not support load() with multiple paths: '<paths>'. Please load each path individually."
-    ],
-    "sqlState" : "42K02"
-  },
   "MULTIPLE_TIME_TRAVEL_SPEC" : {
     "message" : [
       "Cannot specify time travel in both the time travel clause and options."

diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 1e088dde79f6..cba6a24b8699 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1407,12 +1407,6 @@ For more details see [MISSING_ATTRIBUTES](sql-error-conditions-missing-attribute
 
 The query does not include a GROUP BY clause. Add GROUP BY or turn it into the window functions using OVER clauses.
 
-### MULTIPLE_PATHS_UNSUPPORTED
-
-[SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
-
-Data source '`<provider>`' does not support load() with multiple paths: '`<paths>`'. Please load each path individually.
-
 ### MULTIPLE_TIME_TRAVEL_SPEC
 
 [SQLSTATE: 42K0E](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 394f01aa36f3..603de520b18b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3129,14 +3129,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
       "config" -> SQLConf.LEGACY_PATH_OPTION_BEHAVIOR.key))
   }
 
-  def multiplePathsUnsupportedError(provider: String, paths: Seq[String]): Throwable = {
-    new AnalysisException(
-      errorClass = "MULTIPLE_PATHS_UNSUPPORTED",
-      messageParameters = Map(
-        "provider" -> provider,
-        "paths" -> paths.mkString("[", ", ", "]")))
-  }
-
   def pathOptionNotSetCorrectlyWhenWritingError(): Throwable = {
     new AnalysisException(
       errorClass = "_LEGACY_ERROR_TEMP_1307",

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index c36dc1436f58..7fadbbfac687 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -244,16 +244,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
     val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
-    // Unless the legacy path option behavior is enabled, the extraOptions here
-    // should not include "path" or "paths" as keys.
-    // Add path to the options field. Note currently it only supports a single path.
-    val optionsWithPath = if (paths.isEmpty) {
-      extraOptions
-    } else if (paths.length == 1) {
-      extraOptions + ("path" -> paths.head)
-    } else {
-      throw QueryCompilationErrors.multiplePathsUnsupportedError(source, paths)
-    }
+    // Add `path` and `paths` options to the extra options if specified.
+    val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
     val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
     Dataset.ofRows(sparkSession, plan)
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index c4e7bf23cace..3dde20ac44e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging {
   }
 
   private lazy val objectMapper = new ObjectMapper()
-  private def getOptionsWithPaths(
+  def getOptionsWithPaths(
       extraOptions: CaseInsensitiveMap[String],
       paths: String*): CaseInsensitiveMap[String] = {
     if (paths.isEmpty) {

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index c45b995a3de4..bd0b08cbec8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -160,12 +160,19 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val dataSourceScript =
       s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
+        |import json
+        |
         |class SimpleDataSourceReader(DataSourceReader):
         |    def __init__(self, options):
         |        self.options = options
         |
         |    def partitions(self):
-        |        paths = self.options.get("path", [])
+        |        if "paths" in self.options:
+        |            paths = json.loads(self.options["paths"])
+        |        elif "path" in self.options:
+        |            paths = [self.options["path"]]
+        |        else:
+        |            paths = []
         |        return paths
         |
         |    def read(self, path):
@@ -186,11 +193,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     spark.dataSource.registerPython("test", dataSource)
     checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
     checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
-    checkError(
-      exception = intercept[AnalysisException](spark.read.format("test").load("1", "2")),
-      errorClass = "MULTIPLE_PATHS_UNSUPPORTED",
-      parameters = Map("provider" -> "test", "paths" -> "[1, 2]")
-    )
+    checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))
   }
 
   test("reader not implemented") {
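
Note (not part of the patches): after [PATCH 4/4], a Python data source no longer receives `paths` or `userSpecifiedSchema` in its constructor. Arguments to `load()` reach the source through the options map instead: a single path arrives as a `path` option and multiple paths as a JSON-encoded `paths` option, built on the JVM side by `DataSourceV2Utils.getOptionsWithPaths`. The sketch below is a minimal illustration of that shape, written only against the `pyspark.sql.datasource` API surface visible in these diffs; the `PathEchoDataSource` and `PathEchoReader` names are invented for the example.

# Illustrative sketch only: mirrors the option handling used by the updated
# SimpleDataSourceReader script in PythonDataSourceSuite above.
import json

from pyspark.sql.datasource import DataSource, DataSourceReader


class PathEchoReader(DataSourceReader):
    def __init__(self, options):
        self.options = options

    def partitions(self):
        # load("a", "b") surfaces as a JSON-encoded "paths" option,
        # load("a") as a plain "path" option, and load() sets neither.
        if "paths" in self.options:
            return json.loads(self.options["paths"])
        elif "path" in self.options:
            return [self.options["path"]]
        return []

    def read(self, partition):
        # Emit one row per path so the option-to-partition mapping is easy to inspect.
        yield (partition, 1)


class PathEchoDataSource(DataSource):
    def schema(self):
        return "id STRING, value INT"

    def reader(self, schema):
        return PathEchoReader(self.options)

Registered with `spark.dataSource.register(PathEchoDataSource)` (the default `name()` is the class name, as the first test in test_python_datasource.py shows), a call like `spark.read.format("PathEchoDataSource").load("1", "2")` should then produce one row per path, matching the expectations restored in the Scala suite.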