
Commit f2f0b83

Merge pull request #13 from huggingface/enable-content-defined-chunking
Enable content defined chunking
2 parents 75747c1 + fc69164 commit f2f0b83

File tree

2 files changed: +17, -7 lines changed


pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -11,8 +11,9 @@ license = {text = "Apache License 2.0"}
 readme = "README.md"
 requires-python = ">=3.9"
 dependencies = [
-    "datasets>=3.2",
-    "huggingface-hub>=0.27.1",
+    "datasets>=4.0",
+    "huggingface-hub>=0.34.4",
+    "pyarrow>=21.0.0",
 ]
 
 [dependency-groups]
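
The pyarrow pin moves to 21.0.0, the release line whose Parquet writer accepts the use_content_defined_chunking option used in the sink change below. A quick, hedged way to confirm a local environment meets the new floors (this snippet is not part of the change itself):

# Sanity check against the new minimum versions pinned above.
import datasets
import huggingface_hub
import pyarrow

print("datasets:", datasets.__version__, "(needs >= 4.0)")
print("huggingface_hub:", huggingface_hub.__version__, "(needs >= 0.34.4)")
print("pyarrow:", pyarrow.__version__, "(needs >= 21.0.0)")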

pyspark_huggingface/huggingface_sink.py

Lines changed: 14 additions & 5 deletions
@@ -21,6 +21,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 class HuggingFaceSink(DataSource):
     """
     A DataSource for writing Spark DataFrames to HuggingFace Datasets.
@@ -125,8 +126,9 @@ def __init__(
         token: str,
         endpoint: Optional[str] = None,
         row_group_size: Optional[int] = None,
-        max_bytes_per_file=500_000_000,
-        max_operations_per_commit=100,
+        max_bytes_per_file: int = 500_000_000,
+        max_operations_per_commit: int = 100,
+        use_content_defined_chunking: bool = True,
         **kwargs,
     ):
         import uuid
@@ -144,6 +146,7 @@ def __init__(
         self.row_group_size = row_group_size
         self.max_bytes_per_file = max_bytes_per_file
         self.max_operations_per_commit = max_operations_per_commit
+        self.use_content_defined_chunking = use_content_defined_chunking
         self.kwargs = kwargs
 
         # Use a unique filename prefix to avoid conflicts with existing files
@@ -210,10 +213,9 @@ def flush(writer: pq.ParquetWriter):
                f"{self.prefix}-{self.uuid}-part-{partition_id}-{num_files}.parquet"
            )
            num_files += 1
-            parquet.seek(0)
 
            addition = CommitOperationAdd(
-                path_in_repo=name, path_or_fileobj=parquet
+                path_in_repo=name, path_or_fileobj=parquet.getvalue()
            )
            api.preupload_lfs_files(
                repo_id=self.repo_id,
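
Side note on the path_or_fileobj change: CommitOperationAdd accepts raw bytes as well as a file object, and handing it parquet.getvalue() makes the buffer's current position irrelevant, which is why the parquet.seek(0) call can go away. A minimal, standalone sketch of that upload flow with placeholder names (it assumes an existing dataset repo and a write token; it is not this repository's code):

# Illustrative sketch only; "username/dataset" and the file name are placeholders.
import io

import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import CommitOperationAdd, HfApi

buffer = io.BytesIO()
table = pa.table({"id": [1, 2, 3]})
with pq.ParquetWriter(buffer, table.schema) as writer:
    writer.write_table(table)

api = HfApi()
# Passing bytes instead of the BytesIO object means no seek(0) is needed before upload.
addition = CommitOperationAdd(path_in_repo="part-0.parquet", path_or_fileobj=buffer.getvalue())
api.preupload_lfs_files(repo_id="username/dataset", additions=[addition], repo_type="dataset")
api.create_commit(
    repo_id="username/dataset",
    repo_type="dataset",
    operations=[addition],
    commit_message="Add part-0.parquet",
)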
@@ -232,7 +234,14 @@ def flush(writer: pq.ParquetWriter):
        Limiting the size is necessary because we are writing them in memory.
        """
        while True:
-            with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer:
+            with pq.ParquetWriter(
+                parquet,
+                schema=schema,
+                **{
+                    "use_content_defined_chunking": self.use_content_defined_chunking,
+                    **self.kwargs
+                }
+            ) as writer:
                num_batches = 0
                for batch in iterator:  # Start iterating from where we left off
                    writer.write_batch(batch, row_group_size=self.row_group_size)
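
The merged keyword dict above spreads self.kwargs after the chunking default, so any extra writer options supplied by the caller are forwarded alongside the new flag. A hedged, standalone version of the same pattern (extra_kwargs stands in for self.kwargs; it assumes a pyarrow build that accepts use_content_defined_chunking, i.e. pyarrow >= 21.0.0):

# Standalone sketch of the merged-kwargs pattern, outside the sink class.
import io

import pyarrow as pa
import pyarrow.parquet as pq

extra_kwargs = {"compression": "zstd"}  # stand-in for self.kwargs
table = pa.table({"id": [1, 2, 3], "text": ["a", "b", "c"]})
sink = io.BytesIO()
with pq.ParquetWriter(
    sink,
    schema=table.schema,
    **{"use_content_defined_chunking": True, **extra_kwargs},
) as writer:
    writer.write_table(table)
print(f"wrote {sink.getbuffer().nbytes} bytes with content-defined chunking")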
