Commit 138e59a
[data] Handle nullable fields in schema across blocks for parquet files (#48478)
## Why are these changes needed?

When writing blocks to parquet, some blocks may have fields that differ ONLY in nullability. By default, the write is rejected, because those blocks' schemas do not exactly match the schema the `ParquetWriter` was opened with. However, we can allow the write by tweaking the schema: this PR goes through all blocks before writing them to parquet, merges schemas that differ only in the nullability of their fields, and casts each table to the newly merged schema so that the write can succeed.

## Related issue number

Closes #48102

---------

Signed-off-by: rickyx <[email protected]>
1 parent bcee207 commit 138e59a
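As context, here is a minimal pyarrow-only sketch of the failure mode this commit addresses (my reconstruction, not part of the commit; the path `/tmp/demo.parquet` is illustrative): `pq.ParquetWriter.write_table` rejects a table whose schema differs from the writer's schema, even when the only difference is field nullability.

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Two tables whose schemas differ ONLY in the nullability of field "a".
required = pa.schema([pa.field("a", pa.int64(), nullable=False)])
optional = pa.schema([pa.field("a", pa.int64(), nullable=True)])
t1 = pa.table({"a": [1, 2]}, schema=required)
t2 = pa.table({"a": [3, 4]}, schema=optional)

with pq.ParquetWriter("/tmp/demo.parquet", t1.schema) as writer:
    writer.write_table(t1)
    # Raises: the table schema does not match the schema used to open the
    # writer, even though the data itself is perfectly writable.
    writer.write_table(t2)
```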

File tree

2 files changed: +33 −5 lines changed

python/ray/data/_internal/datasource/parquet_datasink.py

Lines changed: 9 additions & 5 deletions
@@ -58,6 +58,7 @@ def write(
         blocks: Iterable[Block],
         ctx: TaskContext,
     ) -> None:
+        import pyarrow as pa
         import pyarrow.parquet as pq
 
         blocks = list(blocks)
@@ -72,16 +73,19 @@ def write(
         write_kwargs = _resolve_kwargs(
             self.arrow_parquet_args_fn, **self.arrow_parquet_args
         )
-        schema = write_kwargs.pop("schema", None)
-        if schema is None:
-            schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
+        user_schema = write_kwargs.pop("schema", None)
 
         def write_blocks_to_path():
             with self.open_output_stream(write_path) as file:
                 tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
-                with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
+                if user_schema is None:
+                    output_schema = pa.unify_schemas([table.schema for table in tables])
+                else:
+                    output_schema = user_schema
+
+                with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
                     for table in tables:
-                        table = table.cast(schema)
+                        table = table.cast(output_schema)
                         writer.write_table(table)
 
         logger.debug(f"Writing {write_path} file.")

python/ray/data/tests/test_parquet.py

Lines changed: 24 additions & 0 deletions
@@ -1331,6 +1331,30 @@ def test_write_with_schema(ray_start_regular_shared, tmp_path):
     assert pq.read_table(tmp_path).schema == schema
 
 
+@pytest.mark.parametrize(
+    "row_data",
+    [
+        [{"a": 1, "b": None}, {"a": 1, "b": 2}],
+        [{"a": None, "b": 3}, {"a": 1, "b": 2}],
+        [{"a": None, "b": 1}, {"a": 1, "b": None}],
+    ],
+    ids=["row1_b_null", "row1_a_null", "row_each_null"],
+)
+def test_write_auto_infer_nullable_fields(
+    tmp_path, ray_start_regular_shared, row_data, restore_data_context
+):
+    """
+    Test that when writing multiple blocks, we can automatically infer nullable
+    fields.
+    """
+    ctx = DataContext.get_current()
+    # So that we force multiple blocks on mapping.
+    ctx.target_max_block_size = 1
+    ds = ray.data.range(len(row_data)).map(lambda row: row_data[row["id"]])
+    # So we force writing to a single file.
+    ds.write_parquet(tmp_path, num_rows_per_file=2)
+
+
 if __name__ == "__main__":
     import sys
 
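A hedged follow-up check one could append to this test (my addition, not in the commit) to confirm the write now succeeds and preserves all rows:

```python
import pyarrow.parquet as pq

# Read the written dataset back; both parametrized rows should survive.
table = pq.read_table(tmp_path)
assert table.num_rows == len(row_data)
assert set(table.column_names) == {"a", "b"}
```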