diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index e20d44039a69..bdedbac3544e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,18 +15,25 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING
+from collections import UserDict
+from typing import Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING
 
 from pyspark.sql import Row
 from pyspark.sql.types import StructType
 from pyspark.errors import PySparkNotImplementedError
 
 if TYPE_CHECKING:
-    from pyspark.sql._typing import OptionalPrimitiveType
     from pyspark.sql.session import SparkSession
 
-__all__ = ["DataSource", "DataSourceReader", "DataSourceWriter", "DataSourceRegistration"]
+__all__ = [
+    "DataSource",
+    "DataSourceReader",
+    "DataSourceWriter",
+    "DataSourceRegistration",
+    "InputPartition",
+    "WriterCommitMessage",
+]
 
 
 class DataSource(ABC):
@@ -45,15 +52,14 @@ class DataSource(ABC):
     .. versionadded: 4.0.0
     """
 
-    @final
-    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
+    def __init__(self, options: Dict[str, str]) -> None:
         """
         Initializes the data source with user-provided options.
 
         Parameters
         ----------
         options : dict
-            A dictionary representing the options for this data source.
+            A case-insensitive dictionary representing the options for this data source.
 
         Notes
         -----
@@ -403,3 +409,36 @@ def register(
         assert sc._jvm is not None
         ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
         self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
+
+
+class CaseInsensitiveDict(UserDict):
+    """
+    A case-insensitive map of string keys to values.
+
+    This is used by Python data source options to ensure consistent case insensitivity.
+ """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.update(*args, **kwargs) + + def __setitem__(self, key: str, value: Any) -> None: + super().__setitem__(key.lower(), value) + + def __getitem__(self, key: str) -> Any: + return super().__getitem__(key.lower()) + + def __delitem__(self, key: str) -> None: + super().__delitem__(key.lower()) + + def __contains__(self, key: object) -> bool: + if isinstance(key, str): + return super().__contains__(key.lower()) + return False + + def update(self, *args: Any, **kwargs: Any) -> None: + for k, v in dict(*args, **kwargs).items(): + self[k] = v + + def copy(self) -> "CaseInsensitiveDict": + return type(self)(self) diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index ce629b2718e2..79414cb7ed69 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -26,6 +26,7 @@ InputPartition, DataSourceWriter, WriterCommitMessage, + CaseInsensitiveDict, ) from pyspark.sql.types import Row, StructType from pyspark.testing import assertDataFrameEqual @@ -346,6 +347,26 @@ def test_custom_json_data_source_abort(self): text = file.read() assert text == "failed" + def test_case_insensitive_dict(self): + d = CaseInsensitiveDict({"foo": 1, "Bar": 2}) + self.assertEqual(d["foo"], d["FOO"]) + self.assertEqual(d["bar"], d["BAR"]) + self.assertTrue("baR" in d) + d["BAR"] = 3 + self.assertEqual(d["BAR"], 3) + # Test update + d.update({"BaZ": 3}) + self.assertEqual(d["BAZ"], 3) + d.update({"FOO": 4}) + self.assertEqual(d["foo"], 4) + # Test delete + del d["FoO"] + self.assertFalse("FOO" in d) + # Test copy + d2 = d.copy() + self.assertEqual(d2["BaR"], 3) + self.assertEqual(d2["baz"], 3) + class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase): ... diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index 1ba4dc9e8a3c..a377911c6e9b 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -29,7 +29,7 @@ write_with_length, SpecialLengths, ) -from pyspark.sql.datasource import DataSource +from pyspark.sql.datasource import DataSource, CaseInsensitiveDict from pyspark.sql.types import _parse_datatype_json_string, StructType from pyspark.util import handle_worker_exception from pyspark.worker_util import ( @@ -120,7 +120,7 @@ def main(infile: IO, outfile: IO) -> None: ) # Receive the options. - options = dict() + options = CaseInsensitiveDict() num_options = read_int(infile) for _ in range(num_options): key = utf8_deserializer.loads(infile) @@ -129,7 +129,7 @@ def main(infile: IO, outfile: IO) -> None: # Instantiate a data source. 
         try:
-            data_source = data_source_cls(options=options)
+            data_source = data_source_cls(options=options)  # type: ignore
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 36b3c23b3379..0ba6fc6eb17f 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -30,7 +30,7 @@
     SpecialLengths,
 )
 from pyspark.sql import Row
-from pyspark.sql.datasource import DataSource, WriterCommitMessage
+from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
         return_col_name = return_type[0].name
 
         # Receive the options.
-        options = dict()
+        options = CaseInsensitiveDict()
         num_options = read_int(infile)
         for _ in range(num_options):
             key = utf8_deserializer.loads(infile)
@@ -153,7 +153,7 @@ def main(infile: IO, outfile: IO) -> None:
         overwrite = read_bool(infile)
 
         # Instantiate a data source.
-        data_source = data_source_cls(options=options)
+        data_source = data_source_cls(options=options)  # type: ignore
 
         # Instantiate the data source writer.
         writer = data_source.writer(schema, overwrite)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 3e7cd82db8d7..dd065c97cb02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -772,4 +772,47 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-46568: case insensitive options") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import (
+         |    DataSource, DataSourceReader, DataSourceWriter, WriterCommitMessage)
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def __init__(self, options):
+         |        self.options = options
+         |
+         |    def read(self, partition):
+         |        foo = self.options.get("Foo")
+         |        bar = self.options.get("BAR")
+         |        baz = "BaZ" in self.options
+         |        yield (foo, bar, baz)
+         |
+         |class SimpleDataSourceWriter(DataSourceWriter):
+         |    def __init__(self, options):
+         |        self.options = options
+         |
+         |    def write(self, row):
+         |        if "FOO" not in self.options or "BAR" not in self.options:
+         |            raise Exception("FOO or BAR not found")
+         |        return WriterCommitMessage()
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "a string, b string, c string"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader(self.options)
+         |
+         |    def writer(self, schema, overwrite):
+         |        return SimpleDataSourceWriter(self.options)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.option("foo", 1).option("bar", 2).option("BAZ", 3)
+      .format(dataSourceName).load()
+    checkAnswer(df, Row("1", "2", "true"))
+    df.write.option("foo", 1).option("bar", 2).format(dataSourceName).mode("append").save()
+  }
 }
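
For reference, a minimal standalone sketch of the case-insensitive option lookup this patch introduces, assuming a PySpark build that already contains the CaseInsensitiveDict added above; the option names used here are illustrative only, not required by the API.

    # Sketch only: assumes a PySpark build that includes the CaseInsensitiveDict above.
    # The option names "Path" and "numPartitions" are made up for illustration.
    from pyspark.sql.datasource import CaseInsensitiveDict

    options = CaseInsensitiveDict({"Path": "/tmp/data", "numPartitions": "4"})
    assert options["path"] == "/tmp/data"       # lookups ignore key case
    assert options.get("NUMPARTITIONS") == "4"  # .get() routes through __getitem__ via UserDict
    assert "PATH" in options                    # membership checks ignore case too
    options["PATH"] = "/tmp/other"              # writes collapse onto the lowercased key
    assert options["path"] == "/tmp/other" and len(options) == 2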