17 changes: 3 additions & 14 deletions python/pyspark/sql/datasource.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from abc import ABC, abstractmethod
from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, TYPE_CHECKING

from pyspark import since
from pyspark.sql import Row
@@ -45,30 +45,19 @@ class DataSource(ABC):
"""

@final
def __init__(
self,
paths: List[str],
userSpecifiedSchema: Optional[StructType],
Contributor:
why do we also remove user specified schema?

Contributor Author:
This field is actually not used. Both the reader and writer functions take in the schema parameter, and we can pass in the actual schema there.
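
For example (a rough sketch, not code from this patch; the Example* names are made up), a source now gets the resolved schema through reader() rather than from a stored field:

from pyspark.sql.datasource import DataSource, DataSourceReader

class ExampleReader(DataSourceReader):
    def __init__(self, schema, options):
        # The resolved schema (user-specified or inferred) is handed in here,
        # so the DataSource no longer needs to keep userSpecifiedSchema around.
        self.schema = schema
        self.options = options

    def partitions(self):
        # A single partition for this sketch.
        return iter([0])

    def read(self, partition):
        # Yield rows matching self.schema; empty for this sketch.
        return iter([])

class ExampleDataSource(DataSource):
    def schema(self):
        return "x INT, y STRING"

    def reader(self, schema) -> DataSourceReader:
        # Spark passes the schema in; forward it to the reader.
        return ExampleReader(schema, self.options)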

options: Dict[str, "OptionalPrimitiveType"],
) -> None:
def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
"""
Initializes the data source with user-provided information.
Initializes the data source with user-provided options.

Parameters
----------
paths : list
A list of paths to the data source.
userSpecifiedSchema : StructType, optional
The user-specified schema of the data source.
options : dict
A dictionary representing the options for this data source.

Notes
-----
This method should not be overridden.
"""
self.paths = paths
self.userSpecifiedSchema = userSpecifiedSchema
self.options = options

@classmethod
36 changes: 12 additions & 24 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -30,7 +30,7 @@ class MyDataSource(DataSource):
...

options = dict(a=1, b=2)
ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
ds = MyDataSource(options=options)
self.assertEqual(ds.options, options)
self.assertEqual(ds.name(), "MyDataSource")
with self.assertRaises(NotImplementedError):
@@ -53,8 +53,7 @@ def test_in_memory_data_source(self):
class InMemDataSourceReader(DataSourceReader):
DEFAULT_NUM_PARTITIONS: int = 3

def __init__(self, paths, options):
self.paths = paths
def __init__(self, options):
self.options = options

def partitions(self):
@@ -76,7 +75,7 @@ def schema(self):
return "x INT, y STRING"

def reader(self, schema) -> "DataSourceReader":
return InMemDataSourceReader(self.paths, self.options)
return InMemDataSourceReader(self.options)

self.spark.dataSource.register(InMemoryDataSource)
df = self.spark.read.format("memory").load()
@@ -91,14 +90,13 @@ def test_custom_json_data_source(self):
import json

class JsonDataSourceReader(DataSourceReader):
def __init__(self, paths, options):
self.paths = paths
def __init__(self, options):
self.options = options

def partitions(self):
return iter(self.paths)

def read(self, path):
def read(self, partition):
path = self.options.get("path")
if path is None:
raise Exception("path is not specified")
with open(path, "r") as file:
for line in file.readlines():
if line.strip():
@@ -114,28 +112,18 @@ def schema(self):
return "name STRING, age INT"

def reader(self, schema) -> "DataSourceReader":
return JsonDataSourceReader(self.paths, self.options)
return JsonDataSourceReader(self.options)

self.spark.dataSource.register(JsonDataSource)
path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
df1 = self.spark.read.format("my-json").load(path1)
self.assertEqual(df1.rdd.getNumPartitions(), 1)
assertDataFrameEqual(
df1,
self.spark.read.format("my-json").load(path1),
[Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
)

df2 = self.spark.read.format("my-json").load([path1, path2])
self.assertEqual(df2.rdd.getNumPartitions(), 2)
assertDataFrameEqual(
df2,
[
Row(name="Michael", age=None),
Row(name="Andy", age=30),
Row(name="Justin", age=19),
Row(name="Jonathan", age=None),
],
self.spark.read.format("my-json").load(path2),
[Row(name="Jonathan", age=None)],
)


15 changes: 2 additions & 13 deletions python/pyspark/sql/worker/create_data_source.py
@@ -17,7 +17,7 @@
import inspect
import os
import sys
from typing import IO, List
from typing import IO

from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
@@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None:
The JVM sends the following information to this process:
- a `DataSource` class representing the data source to be created.
- a provider name in string.
- a list of paths in string.
- an optional user-specified schema in json string.
- a dictionary of options in string.
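
A condensed sketch of the receive side after this change (illustrative only; the provider/schema/options order comes from the list above, while the exact framing of the options is an assumption here):

from pyspark.serializers import read_bool, read_int, UTF8Deserializer
from pyspark.sql.types import _parse_datatype_json_string

utf8_deserializer = UTF8Deserializer()

def receive_data_source_info(infile):
    # (the pickled DataSource class itself is read before this point)
    provider = utf8_deserializer.loads(infile)        # provider name
    user_specified_schema = None
    if read_bool(infile):                             # schema is optional
        user_specified_schema = _parse_datatype_json_string(
            utf8_deserializer.loads(infile)
        )
    # Assumed framing: options sent as a count followed by key/value strings.
    options = {}
    for _ in range(read_int(infile)):
        key = utf8_deserializer.loads(infile)
        options[key] = utf8_deserializer.loads(infile)
    return provider, user_specified_schema, options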

@@ -107,12 +106,6 @@ def main(infile: IO, outfile: IO) -> None:
},
)

# Receive the paths.
num_paths = read_int(infile)
paths: List[str] = []
for _ in range(num_paths):
paths.append(utf8_deserializer.loads(infile))

# Receive the user-specified schema
user_specified_schema = None
if read_bool(infile):
@@ -136,11 +129,7 @@ def main(infile: IO, outfile: IO) -> None:

# Instantiate a data source.
try:
data_source = data_source_cls(
paths=paths,
userSpecifiedSchema=user_specified_schema, # type: ignore
options=options,
)
data_source = data_source_cls(options=options)
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
@@ -244,9 +244,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
// Unless the legacy path option behavior is enabled, the extraOptions here
// should not include "path" or "paths" as keys.
val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
// Add `path` and `paths` options to the extra options if specified.
val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
}
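
Illustration (not part of this diff) of what the change above means on the Python side, given a SparkSession spark and the my-json source from the Python test above; the file name is made up:

# A user-supplied load path now reaches the Python data source only via options.
df = spark.read.format("my-json").load("people.json")
# Inside the reader:
#   self.options.get("path") == "people.json"
#
# When several paths are passed, e.g. .load(["a.json", "b.json"]), they are
# JSON-encoded under a single "paths" option instead; the SimpleDataSourceReader
# in the suite change further below decodes that with json.loads.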

@@ -35,7 +35,6 @@ class DataSourceManager {
private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
Seq[String], // paths
Option[StructType], // user specified schema
CaseInsensitiveMap[String] // options
) => LogicalPlan
@@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging {
}

private lazy val objectMapper = new ObjectMapper()
private def getOptionsWithPaths(
def getOptionsWithPaths(
extraOptions: CaseInsensitiveMap[String],
paths: String*): CaseInsensitiveMap[String] = {
if (paths.isEmpty) {
@@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
def builder(
sparkSession: SparkSession,
provider: String,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveMap[String]): LogicalPlan = {

val runner = new UserDefinedPythonDataSourceRunner(
dataSourceCls, provider, paths, userSpecifiedSchema, options)
dataSourceCls, provider, userSpecifiedSchema, options)

val result = runner.runInPython()
val pickledDataSourceInstance = result.dataSource
@@ -68,10 +67,9 @@ def apply(
def apply(
sparkSession: SparkSession,
provider: String,
paths: Seq[String] = Seq.empty,
userSpecifiedSchema: Option[StructType] = None,
options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
Dataset.ofRows(sparkSession, plan)
}
}
@@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult(
class UserDefinedPythonDataSourceRunner(
dataSourceCls: PythonFunction,
provider: String,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveMap[String])
extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
@@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner(
// Send the provider name
PythonWorkerUtils.writeUTF(provider, dataOut)

// Send the paths
dataOut.writeInt(paths.length)
paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))

// Send the user-specified schema, if provided
dataOut.writeBoolean(userSpecifiedSchema.isDefined)
userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))
@@ -160,13 +160,20 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
|import json
|
|class SimpleDataSourceReader(DataSourceReader):
| def __init__(self, paths, options):
@HyukjinKwon (Member) commented on Nov 17, 2023:
regardless we should remove this paths in the interface. Not all Python Datasources require paths.

| self.paths = paths
| def __init__(self, options):
| self.options = options
|
| def partitions(self):
| return iter(self.paths)
| if "paths" in self.options:
| paths = json.loads(self.options["paths"])
| elif "path" in self.options:
| paths = [self.options["path"]]
| else:
| paths = []
| return paths
|
| def read(self, path):
| yield (path, 1)
@@ -180,11 +187,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
| return "id STRING, value INT"
|
| def reader(self, schema):
| return SimpleDataSourceReader(self.paths, self.options)
| return SimpleDataSourceReader(self.options)
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
spark.dataSource.registerPython("test", dataSource)

checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))