128 changes: 86 additions & 42 deletions pyspark_huggingface/huggingface_sink.py
@@ -1,7 +1,7 @@
import ast
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

from pyspark.sql.datasource import (
DataSource,
@@ -11,7 +11,12 @@
from pyspark.sql.types import StructType

if TYPE_CHECKING:
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from huggingface_hub import (
CommitOperation,
CommitOperationAdd,
HfApi,
)
from huggingface_hub.hf_api import RepoFile, RepoFolder
from pyarrow import RecordBatch

logger = logging.getLogger(__name__)
@@ -27,8 +32,8 @@ class HuggingFaceSink(DataSource):
Data Source Options:
- token (str, required): HuggingFace API token for authentication.
- path (str, required): HuggingFace repository ID, e.g. `{username}/{dataset}`.
- path_in_repo (str): Path within the repository to write the data. Defaults to the root.
- split (str): Split name to write the data to. Defaults to `train`. Only `train`, `test`, and `validation` are supported.
- path_in_repo (str): Path within the repository to write the data. Defaults to "data".
- split (str): Split name to write the data to. Defaults to `train`.
- revision (str): Branch, tag, or commit to write to. Defaults to the main branch.
- endpoint (str): Custom HuggingFace API endpoint URL.
- max_bytes_per_file (int): Maximum size of each Parquet file.
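
For reference, a minimal write using these options might look like the sketch below (assuming the pyspark_huggingface package is installed so the "huggingface" format resolves; the token value and repository ID are placeholders):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(10).withColumnRenamed("id", "value")

(
    df.write.format("huggingface")
    .option("token", "hf_xxx")        # HuggingFace API token (placeholder)
    .option("split", "train")         # optional, defaults to "train"
    .option("path_in_repo", "data")   # optional, defaults to "data"
    .mode("append")                   # or "overwrite" to replace the split
    .save("username/my-dataset")      # repository ID (placeholder)
)
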
@@ -125,7 +130,9 @@ def __init__(
import uuid

self.repo_id = repo_id
self.path_in_repo = (path_in_repo or "").strip("/")
self.path_in_repo = (
path_in_repo.strip("/") if path_in_repo is not None else "data"
)
self.split = split or "train"
self.revision = revision
self.schema = schema
@@ -140,26 +147,7 @@ def __init__(
# Use a unique filename prefix to avoid conflicts with existing files
self.uuid = uuid.uuid4()

self.validate()

def validate(self):
if self.split not in ["train", "test", "validation"]:
"""
TODO: Add support for custom splits.

For custom split names to be recognized, the files must have path with format:
`data/{split}-{iiiii}-of-{nnnnn}.parquet`
where `iiiii` is the part number and `nnnnn` is the total number of parts, both padded to 5 digits.
Example: `data/custom-00000-of-00002.parquet`

Therefore the current usage of UUID to avoid naming conflicts won't work for custom split names.
To fix this we can rename the files in the commit phase to satisfy the naming convention.
"""
raise NotImplementedError(
f"Only 'train', 'test', and 'validation' splits are supported. Got '{self.split}'."
)

def get_api(self):
def _get_api(self):
from huggingface_hub import HfApi

return HfApi(token=self.token, endpoint=self.endpoint)
@@ -168,16 +156,11 @@ def get_api(self):
def prefix(self) -> str:
return f"{self.path_in_repo}/{self.split}".strip("/")

def get_delete_operations(self) -> Iterator["CommitOperationDelete"]:
def _list_split(self, api: "HfApi") -> Iterator[Union["RepoFile", "RepoFolder"]]:
"""
Get the commit operations to delete all existing Parquet files.
This is used when `overwrite=True` to clear the target directory.
Get all existing files of the current split.
"""
from huggingface_hub import CommitOperationDelete
from huggingface_hub.errors import EntryNotFoundError
from huggingface_hub.hf_api import RepoFolder

api = self.get_api()

try:
objects = api.list_repo_tree(
@@ -190,11 +173,54 @@ def get_delete_operations(self) -> Iterator["CommitOperationDelete"]:
)
for obj in objects:
if obj.path.startswith(self.prefix):
yield CommitOperationDelete(
path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder)
yield obj
except EntryNotFoundError:
pass

def _prepare_operations(
self, api: "HfApi", additions: List["CommitOperationAdd"]
) -> Iterator["CommitOperation"]:
"""
Prepare operations for upload.
- Rename files to be recognizable by HuggingFace: `{split}-{current:05d}-of-{total:05d}.parquet`
- Delete existing files if `overwrite=True`

See: https://huggingface.co/docs/hub/en/datasets-file-names-and-splits

Note: additions are mutated to update the path_in_repo to the new filename.
"""
from huggingface_hub import CommitOperationCopy, CommitOperationDelete
from huggingface_hub.hf_api import RepoFile, RepoFolder

count_new = len(additions)
count_existing = 0

def format_path(i):
return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet"

# Rename existing files to have the correct total number of parts
if not self.overwrite:
existing = list(
obj for obj in self._list_split(api) if isinstance(obj, RepoFile)
)
count_existing = len(existing)
for i, obj in enumerate(existing):
new_path = format_path(i)
if obj.path != new_path:
yield CommitOperationCopy(
src_path_in_repo=obj.path, path_in_repo=new_path
)
except EntryNotFoundError as e:
logger.info(f"Writing to a new path: {e}")
yield CommitOperationDelete(path_in_repo=obj.path)
# Otherwise, delete existing files
else:
for obj in self._list_split(api):
yield CommitOperationDelete(
path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder)
)

# Rename additions
for i, addition in enumerate(additions):
addition.path_in_repo = format_path(i + count_existing)

def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
import io
@@ -208,7 +234,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage:
context = TaskContext.get()
partition_id = context.partitionId() if context else 0

api = self.get_api()
api = self._get_api()

schema = to_arrow_schema(self.schema)
num_files = 0
@@ -265,25 +291,43 @@ def flush(writer: pq.ParquetWriter):
return HuggingFaceCommitMessage(additions=additions)

def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override]
import math
api = self._get_api()

api = self.get_api()
operations = [
addition for message in messages for addition in message.additions
]
if self.overwrite: # Delete existing files if overwrite is enabled
operations.extend(self.get_delete_operations())
prepare_operations = list(self._prepare_operations(api, operations))

# First rename existing files or delete files to be overwritten
self._create_commits(
api,
operations=prepare_operations,
message="Prepare for upload using PySpark",
)

# Then upload the new files
# This is a separate commit to avoid conflicts when e.g. a renamed file's old name is the same as a new file's name
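# For instance, overwriting a split whose only file is data/train-00000-of-00001.parquet
# with a single new file deletes that path in the first commit and uploads a fresh file
# under the same path in the second, so the two operations never meet in one commit.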
self._create_commits(
api,
operations=operations,
message="Upload using PySpark",
)

def _create_commits(
self, api: "HfApi", operations: List["CommitOperation"], message: str
) -> None:
"""
Split the commit into multiple parts if necessary.
The HuggingFace API may time out if there are too many operations in a single commit.
"""
import math

num_commits = math.ceil(len(operations) / self.max_operations_per_commit)
for i in range(num_commits):
begin = i * self.max_operations_per_commit
end = (i + 1) * self.max_operations_per_commit
part = operations[begin:end]
commit_message = "Upload using PySpark" + (
commit_message = message + (
f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else ""
)
api.create_commit(
50 changes: 27 additions & 23 deletions tests/test_huggingface_writer.py
@@ -6,7 +6,6 @@
from pyspark.testing import assertDataFrameEqual
from pytest_mock import MockerFixture


# ============== Fixtures & Helpers ==============

@pytest.fixture(scope="session")
@@ -22,8 +21,10 @@ def token():
return os.environ["HF_TOKEN"]


def reader(spark):
return spark.read.format("huggingface").option("token", token())
def load(repo, split):
from datasets import load_dataset

return load_dataset(repo, token=token(), split=split).to_pandas()


def writer(df: DataFrame):
@@ -34,7 +35,7 @@ def writer(df: DataFrame):
def random_df(spark: SparkSession):
from pyspark.sql.functions import rand

return lambda n: spark.range(n).select((rand()).alias("value"))
return lambda n: spark.range(n, numPartitions=2).select((rand()).alias("value"))


@pytest.fixture(scope="session")
@@ -59,41 +60,44 @@ def repo(api, username):

# ============== Tests ==============

def test_basic(spark, repo, random_df):

def test_basic(repo, random_df):
df = random_df(10)
writer(df).mode("append").save(repo)
actual = reader(spark).load(repo)
assertDataFrameEqual(df, actual)
actual = load(repo, "train")
assertDataFrameEqual(actual, df.toPandas())


def test_append(spark, repo, random_df):
@pytest.mark.parametrize("split", ["train", "custom"])
def test_append(repo, random_df, split):
df1 = random_df(10)
df2 = random_df(10)
writer(df1).mode("append").save(repo)
writer(df2).mode("append").save(repo)
actual = reader(spark).load(repo)
writer(df1).options(split=split).mode("append").save(repo)
writer(df2).options(split=split).mode("append").save(repo)
actual = load(repo, split)
expected = df1.union(df2)
assertDataFrameEqual(actual, expected)
assertDataFrameEqual(actual, expected.toPandas())


def test_overwrite(spark, repo, random_df):
@pytest.mark.parametrize("split", ["train", "custom"])
def test_overwrite(repo, random_df, split):
df1 = random_df(10)
df2 = random_df(10)
writer(df1).mode("append").save(repo)
writer(df2).mode("overwrite").save(repo)
actual = reader(spark).load(repo)
assertDataFrameEqual(actual, df2)
writer(df1).options(split=split).mode("append").save(repo)
writer(df2).options(split=split).mode("overwrite").save(repo)
actual = load(repo, split)
assertDataFrameEqual(actual, df2.toPandas())


def test_split(spark, repo, random_df):
def test_split(repo, random_df):
df1 = random_df(10)
df2 = random_df(10)
writer(df1).mode("append").save(repo)
writer(df2).mode("append").options(split="test").save(repo)
actual1 = reader(spark).options(split="train").load(repo)
actual2 = reader(spark).options(split="test").load(repo)
assertDataFrameEqual(actual1, df1)
assertDataFrameEqual(actual2, df2)
writer(df2).mode("append").options(split="custom").save(repo)
actual1 = load(repo, "train")
actual2 = load(repo, "custom")
assertDataFrameEqual(actual1, df1.toPandas())
assertDataFrameEqual(actual2, df2.toPandas())


def test_revision(repo, random_df, api):