16 changes: 16 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -135,6 +135,22 @@ def test_data_source_read_output_row(self):
df = self.spark.read.format("test").load()
assertDataFrameEqual(df, [Row(0, 1)])

def test_data_source_read_output_named_row(self):
self.register_data_source(
read_func=lambda schema, partition: iter([Row(j=1, i=0), Row(i=1, j=2)])
)
df = self.spark.read.format("test").load()
assertDataFrameEqual(df, [Row(0, 1), Row(1, 2)])

def test_data_source_read_output_named_row_with_wrong_schema(self):
self.register_data_source(
read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
)
with self.assertRaisesRegex(
PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
):
self.spark.read.format("test").load().show()

def test_data_source_read_output_none(self):
self.register_data_source(read_func=lambda schema, partition: None)
df = self.spark.read.format("test").load()
24 changes: 22 additions & 2 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -29,6 +29,7 @@
write_int,
SpecialLengths,
)
from pyspark.sql import Row
from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
from pyspark.sql.datasource import DataSource, InputPartition
from pyspark.sql.pandas.types import to_arrow_schema
@@ -234,6 +235,8 @@ def batched(iterator: Iterator, n: int) -> Iterator:

# Convert the results from the `reader.read` method to an iterator of arrow batches.
num_cols = len(column_names)
col_mapping = {name: i for i, name in enumerate(column_names)}
col_name_set = set(column_names)
for batch in batched(output_iter, max_arrow_batch_size):
pylist: List[List] = [[] for _ in range(num_cols)]
for result in batch:
@@ -258,8 +261,25 @@ def batched(iterator: Iterator, n: int) -> Iterator:
},
)

for col in range(num_cols):
pylist[col].append(column_converters[col](result[col]))
# Assign output values by name of the field, not position, if the result is a
# named `Row` object.
if isinstance(result, Row) and hasattr(result, "__fields__"):
Member:
Can we match the implementation with the Python worker? See assign_cols_by_name in worker.py.

Contributor Author:
This is actually different from assign_cols_by_name, which re-arranges Arrow batch columns by the Arrow type names. Here we want to match a single named Row object to the return schema. The only way to tell a named Row(a=1, b=1) apart from an unnamed Row(1, 2) is by checking this __fields__.
(A minimal sketch of this __fields__ check appears after this file's diff.)

# Check if the names are the same as the schema.
if set(result.__fields__) != col_name_set:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
message_parameters={
"expected": str(column_names),
"actual": str(result.__fields__),
},
)
# Assign the values by name.
for name in column_names:
idx = col_mapping[name]
pylist[idx].append(column_converters[idx](result[name]))
else:
for col in range(num_cols):
pylist[col].append(column_converters[col](result[col]))

yield pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
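
To make the __fields__ distinction from the review thread above concrete, here is a minimal sketch, assuming only a local PySpark 3.x installation; the variable names are illustrative and not part of the patch. Only keyword-constructed Rows carry a __fields__ attribute, which is what the hasattr(result, "__fields__") check uses to decide between by-name and by-position assignment.

from pyspark.sql import Row

# Keyword form: the field names are recorded on the instance.
named = Row(i=0, j=1)
# Positional form: just a tuple of values, no field names.
positional = Row(0, 1)

assert hasattr(named, "__fields__")           # values are matched to the schema by name
assert not hasattr(positional, "__fields__")  # values are matched to the schema by position
assert list(named.__fields__) == ["i", "j"]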

@@ -455,6 +455,32 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-46540: data source read output named rows") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
|class SimpleDataSourceReader(DataSourceReader):
| def read(self, partition):
| from pyspark.sql import Row
| yield Row(x = 0, y = 1)
| yield Row(y = 2, x = 1)
| yield Row(2, 3)
| yield (3, 4)
|
|class $dataSourceName(DataSource):
| def schema(self) -> str:
| return "x int, y int"
|
| def reader(self, schema):
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
spark.dataSource.registerPython(dataSourceName, dataSource)
val df = spark.read.format(dataSourceName).load()
checkAnswer(df, Seq(Row(0, 1), Row(1, 2), Row(2, 3), Row(3, 4)))
}

test("SPARK-46424: Support Python metrics") {
assume(shouldTestPandasUDFs)
val dataSourceScript =