
Commit 8647b7d

Use same data source name for reader and writer (#7)
* rename huggingface.py to huggingface_source.py
* add wrapper class to have the same name for both source and sink
* add docstring
* undo formatting changes
* undo formatting changes
1 parent df37cbd commit 8647b7d
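
Context: before this commit the reader registered under the "huggingface" name while the writer registered under "huggingfacesink", so reads and writes needed different format names. The wrapper class added below keeps a single "huggingface" name and delegates to the source or sink. A minimal sketch of the resulting call pattern ("imdb" is taken from the existing docstring examples; the target repo id and token value below are placeholders, not part of this commit):

    # Read: resolved to HuggingFaceSource through the wrapper.
    df = spark.read.format("huggingface").load("imdb")

    # Write: resolved to HuggingFaceSink through the same format name.
    # "user/my-dataset" and "hf_xxx" are placeholder values.
    df.write.format("huggingface") \
        .mode("append") \
        .option("token", "hf_xxx") \
        .save("user/my-dataset")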

File tree

3 files changed: +192 -143 lines changed


pyspark_huggingface/huggingface.py

Lines changed: 29 additions & 139 deletions
@@ -1,161 +1,51 @@
-import ast
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import Optional
 
-from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
-from pyspark.sql.pandas.types import from_arrow_schema
+from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader
 from pyspark.sql.types import StructType
 
-if TYPE_CHECKING:
-    from datasets import DatasetBuilder, IterableDataset
+from pyspark_huggingface.huggingface_sink import HuggingFaceSink
+from pyspark_huggingface.huggingface_source import HuggingFaceSource
+
 
 class HuggingFaceDatasets(DataSource):
     """
-    A DataSource for reading and writing HuggingFace Datasets in Spark.
-
-    This data source allows reading public datasets from the HuggingFace Hub directly into Spark
-    DataFrames. The schema is automatically inferred from the dataset features. The split can be
-    specified using the `split` option. The default split is `train`.
-
-    Name: `huggingface`
-
-    Data Source Options:
-    - split (str): Specify which split to retrieve. Default: train
-    - config (str): Specify which subset or configuration to retrieve.
-    - streaming (bool): Specify whether to read a dataset without downloading it.
-
-    Notes:
-    -----
-    - Currently it can only be used with public datasets. Private or gated ones are not supported.
-
-    Examples
-    --------
-
-    Load a public dataset from the HuggingFace Hub.
-
-    >>> df = spark.read.format("huggingface").load("imdb")
-    DataFrame[text: string, label: bigint]
-
-    >>> df.show()
-    +--------------------+-----+
-    |                text|label|
-    +--------------------+-----+
-    |I rented I AM CUR...|    0|
-    |"I Am Curious: Ye...|    0|
-    |...                 |  ...|
-    +--------------------+-----+
-
-    Load a specific split from a public dataset from the HuggingFace Hub.
+    DataSource for reading and writing HuggingFace Datasets in Spark.
 
-    >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
-    +--------------------+-----+
-    |                text|label|
-    +--------------------+-----+
-    |I love sci-fi and...|    0|
-    |Worth the enterta...|    0|
-    |...                 |  ...|
-    +--------------------+-----+
+    Read
+    ------
+    See :py:class:`HuggingFaceSource` for more details.
 
-    Enable predicate pushdown for Parquet datasets.
-
-    >>> spark.read.format("huggingface") \
-    ...     .option("filters", '[("language_score", ">", 0.99)]') \
-    ...     .option("columns", '["text", "language_score"]') \
-    ...     .load("HuggingFaceFW/fineweb-edu") \
-    ...     .show()
-    +--------------------+------------------+
-    |                text|    language_score|
-    +--------------------+------------------+
-    |died Aug. 28, 181...|0.9901925325393677|
-    |Coyotes spend a g...|0.9902171492576599|
-    |...                 |               ...|
-    +--------------------+------------------+
+    Write
+    ------
+    See :py:class:`HuggingFaceSink` for more details.
     """
 
-    DEFAULT_SPLIT: str = "train"
-
-    def __init__(self, options):
+    # Delegate the source and sink methods to the respective classes.
+    def __init__(self, options: dict):
         super().__init__(options)
-        from datasets import load_dataset_builder
-
-        if "path" not in options or not options["path"]:
-            raise Exception("You must specify a dataset name.")
+        self.options = options
+        self.source: Optional[HuggingFaceSource] = None
+        self.sink: Optional[HuggingFaceSink] = None
 
-        kwargs = dict(self.options)
-        self.dataset_name = kwargs.pop("path")
-        self.config_name = kwargs.pop("config", None)
-        self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
-        self.streaming = kwargs.pop("streaming", "true").lower() == "true"
-        for arg in kwargs:
-            if kwargs[arg].lower() == "true":
-                kwargs[arg] = True
-            elif kwargs[arg].lower() == "false":
-                kwargs[arg] = False
-            else:
-                try:
-                    kwargs[arg] = ast.literal_eval(kwargs[arg])
-                except ValueError:
-                    pass
+    def get_source(self) -> HuggingFaceSource:
+        if self.source is None:
+            self.source = HuggingFaceSource(self.options.copy())
+        return self.source
 
-        self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs)
-        streaming_dataset = self.builder.as_streaming_dataset()
-        if self.split not in streaming_dataset:
-            raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}")
-
-        self.streaming_dataset = streaming_dataset[self.split]
-        if not self.streaming_dataset.features:
-            self.streaming_dataset = self.streaming_dataset._resolve_features()
+    def get_sink(self):
+        if self.sink is None:
+            self.sink = HuggingFaceSink(self.options.copy())
+        return self.sink
 
     @classmethod
     def name(cls):
         return "huggingface"
 
     def schema(self):
-        return from_arrow_schema(self.streaming_dataset.features.arrow_schema)
+        return self.get_source().schema()
 
     def reader(self, schema: StructType) -> "DataSourceReader":
-        return HuggingFaceDatasetsReader(
-            schema,
-            builder=self.builder,
-            split=self.split,
-            streaming_dataset=self.streaming_dataset if self.streaming else None
-        )
-
-
-@dataclass
-class Shard(InputPartition):
-    """ Represents a dataset shard. """
-    index: int
-
-
-class HuggingFaceDatasetsReader(DataSourceReader):
-
-    def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, streaming_dataset: Optional["IterableDataset"]):
-        self.schema = schema
-        self.builder = builder
-        self.split = split
-        self.streaming_dataset = streaming_dataset
-        # Get and validate the split name
-
-    def partitions(self) -> Sequence[InputPartition]:
-        if self.streaming_dataset:
-            return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)]
-        else:
-            return [Shard(index=0)]
+        return self.get_source().reader(schema)
 
-    def read(self, partition: Shard):
-        columns = [field.name for field in self.schema.fields]
-        if self.streaming_dataset:
-            shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index)
-            if shard._ex_iterable.iter_arrow:
-                for _, pa_table in shard._ex_iterable.iter_arrow():
-                    yield from pa_table.select(columns).to_batches()
-            else:
-                for _, example in shard:
-                    yield example
-        else:
-            self.builder.download_and_prepare()
-            dataset = self.builder.as_dataset(self.split)
-            # Get the underlying arrow table of the dataset
-            table = dataset._data
-            yield from table.select(columns).to_batches()
+    def writer(self, schema: StructType, overwrite: bool) -> "DataSourceArrowWriter":
+        return self.get_sink().writer(schema, overwrite)
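
The delegation above is lazy: get_source() and get_sink() build their delegate on first use and cache it, so a read never constructs a HuggingFaceSink and a write never constructs a HuggingFaceSource. A hedged end-to-end sketch of exercising the wrapper, assuming explicit registration is still required (skip the register call if the package registers the data source on import); the repo id and token are placeholders:

    from pyspark.sql import SparkSession

    from pyspark_huggingface.huggingface import HuggingFaceDatasets

    spark = SparkSession.builder.getOrCreate()
    # Register the wrapper once; reads and writes then resolve the "huggingface" name to it.
    spark.dataSource.register(HuggingFaceDatasets)

    # Read path: HuggingFaceDatasets.reader() delegates to HuggingFaceSource.
    df = spark.read.format("huggingface").option("split", "test").load("imdb")

    # Write path: HuggingFaceDatasets.writer() delegates to HuggingFaceSink.
    # "user/my-dataset" and "hf_xxx" are placeholder values.
    (
        df.limit(10)
        .write.format("huggingface")
        .mode("overwrite")
        .option("token", "hf_xxx")
        .save("user/my-dataset")
    )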
pyspark_huggingface/huggingface_source.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+import ast
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Sequence
+
+from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+from pyspark.sql.pandas.types import from_arrow_schema
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+    from datasets import DatasetBuilder, IterableDataset
+
+class HuggingFaceSource(DataSource):
+    """
+    A DataSource for reading and writing HuggingFace Datasets in Spark.
+
+    This data source allows reading public datasets from the HuggingFace Hub directly into Spark
+    DataFrames. The schema is automatically inferred from the dataset features. The split can be
+    specified using the `split` option. The default split is `train`.
+
+    Name: `huggingface`
+
+    Data Source Options:
+    - split (str): Specify which split to retrieve. Default: train
+    - config (str): Specify which subset or configuration to retrieve.
+    - streaming (bool): Specify whether to read a dataset without downloading it.
+
+    Notes:
+    -----
+    - Currently it can only be used with public datasets. Private or gated ones are not supported.
+
+    Examples
+    --------
+
+    Load a public dataset from the HuggingFace Hub.
+
+    >>> df = spark.read.format("huggingface").load("imdb")
+    DataFrame[text: string, label: bigint]
+
+    >>> df.show()
+    +--------------------+-----+
+    |                text|label|
+    +--------------------+-----+
+    |I rented I AM CUR...|    0|
+    |"I Am Curious: Ye...|    0|
+    |...                 |  ...|
+    +--------------------+-----+
+
+    Load a specific split from a public dataset from the HuggingFace Hub.
+
+    >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
+    +--------------------+-----+
+    |                text|label|
+    +--------------------+-----+
+    |I love sci-fi and...|    0|
+    |Worth the enterta...|    0|
+    |...                 |  ...|
+    +--------------------+-----+
+
+    Enable predicate pushdown for Parquet datasets.
+
+    >>> spark.read.format("huggingface") \
+    ...     .option("filters", '[("language_score", ">", 0.99)]') \
+    ...     .option("columns", '["text", "language_score"]') \
+    ...     .load("HuggingFaceFW/fineweb-edu") \
+    ...     .show()
+    +--------------------+------------------+
+    |                text|    language_score|
+    +--------------------+------------------+
+    |died Aug. 28, 181...|0.9901925325393677|
+    |Coyotes spend a g...|0.9902171492576599|
+    |...                 |               ...|
+    +--------------------+------------------+
+    """
+
+    DEFAULT_SPLIT: str = "train"
+
+    def __init__(self, options):
+        super().__init__(options)
+        from datasets import load_dataset_builder
+
+        if "path" not in options or not options["path"]:
+            raise Exception("You must specify a dataset name.")
+
+        kwargs = dict(self.options)
+        self.dataset_name = kwargs.pop("path")
+        self.config_name = kwargs.pop("config", None)
+        self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
+        self.streaming = kwargs.pop("streaming", "true").lower() == "true"
+        for arg in kwargs:
+            if kwargs[arg].lower() == "true":
+                kwargs[arg] = True
+            elif kwargs[arg].lower() == "false":
+                kwargs[arg] = False
+            else:
+                try:
+                    kwargs[arg] = ast.literal_eval(kwargs[arg])
+                except ValueError:
+                    pass
+
+        self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs)
+        streaming_dataset = self.builder.as_streaming_dataset()
+        if self.split not in streaming_dataset:
+            raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}")
+
+        self.streaming_dataset = streaming_dataset[self.split]
+        if not self.streaming_dataset.features:
+            self.streaming_dataset = self.streaming_dataset._resolve_features()
+
+    @classmethod
+    def name(cls):
+        return "huggingfacesource"
+
+    def schema(self):
+        return from_arrow_schema(self.streaming_dataset.features.arrow_schema)
+
+    def reader(self, schema: StructType) -> "DataSourceReader":
+        return HuggingFaceDatasetsReader(
+            schema,
+            builder=self.builder,
+            split=self.split,
+            streaming_dataset=self.streaming_dataset if self.streaming else None
+        )
+
+
+@dataclass
+class Shard(InputPartition):
+    """ Represents a dataset shard. """
+    index: int
+
+
+class HuggingFaceDatasetsReader(DataSourceReader):
+
+    def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, streaming_dataset: Optional["IterableDataset"]):
+        self.schema = schema
+        self.builder = builder
+        self.split = split
+        self.streaming_dataset = streaming_dataset
+        # Get and validate the split name
+
+    def partitions(self) -> Sequence[InputPartition]:
+        if self.streaming_dataset:
+            return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)]
+        else:
+            return [Shard(index=0)]
+
+    def read(self, partition: Shard):
+        columns = [field.name for field in self.schema.fields]
+        if self.streaming_dataset:
+            shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index)
+            if shard._ex_iterable.iter_arrow:
+                for _, pa_table in shard._ex_iterable.iter_arrow():
+                    yield from pa_table.select(columns).to_batches()
+            else:
+                for _, example in shard:
+                    yield example
+        else:
+            self.builder.download_and_prepare()
+            dataset = self.builder.as_dataset(self.split)
+            # Get the underlying arrow table of the dataset
+            table = dataset._data
+            yield from table.select(columns).to_batches()

tests/test_huggingface_writer.py

Lines changed: 2 additions & 4 deletions
@@ -8,12 +8,10 @@
 
 # ============== Fixtures & Helpers ==============
 
+
 @pytest.fixture(scope="session")
 def spark():
-    from pyspark_huggingface.huggingface_sink import HuggingFaceSink
-
     spark = SparkSession.builder.getOrCreate()
-    spark.dataSource.register(HuggingFaceSink)
     yield spark
 
 
@@ -28,7 +26,7 @@ def load(repo, split):
 
 
 def writer(df: DataFrame):
-    return df.write.format("huggingfacesink").option("token", token())
+    return df.write.format("huggingface").option("token", token())
 
 
 @pytest.fixture(scope="session")
