2 changes: 1 addition & 1 deletion .gitignore
@@ -159,4 +159,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
167 changes: 167 additions & 0 deletions demo.ipynb
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "125a1871-6cab-4dc4-9fd5-4e5dbd63ada6",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "38dc7e9e-35fd-4604-9be3-1a1a8749fbcb",
"metadata": {},
"outputs": [],
"source": [
"from pyspark_huggingface import HuggingFaceDatasets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "620d3ecb-b9cb-480c-b300-69198cce7a9c",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9255ffcb-0b61-43dc-b57a-2b8af01a8432",
"metadata": {},
"outputs": [],
"source": [
"spark = SparkSession.builder.getOrCreate()"
]
},
{
"cell_type": "code",
"id": "7c4501a8-26f4-4f52-9dc8-a70393d567b4",
"metadata": {},
"source": [
"spark.dataSource.register(HuggingFaceDatasets)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
"metadata": {},
"outputs": [],
"source": [
"df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8866bdfb-0782-4430-8b1e-09c65e699f41",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Stage 5:> (0 + 1) / 1]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"+--------------------+-----+\n",
"| text|label|\n",
"+--------------------+-----+\n",
"|the rock is desti...| 1|\n",
"|the gorgeously el...| 1|\n",
"|effective but too...| 1|\n",
"|if you sometimes ...| 1|\n",
"|emerges as someth...| 1|\n",
"|the film provides...| 1|\n",
"|offers that rare ...| 1|\n",
"|perhaps no pictur...| 1|\n",
"|steers turns in a...| 1|\n",
"|take care of my c...| 1|\n",
"|this is a film we...| 1|\n",
"|what really surpr...| 1|\n",
"|( wendigo is ) wh...| 1|\n",
"|one of the greate...| 1|\n",
"|ultimately , it p...| 1|\n",
"|an utterly compel...| 1|\n",
"|illuminating if o...| 1|\n",
"|a masterpiece fou...| 1|\n",
"|the movie's ripe ...| 1|\n",
"|offers a breath o...| 1|\n",
"+--------------------+-----+\n",
"only showing top 20 rows\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" "
]
}
],
"source": [
"df.show()"
]
},
{
"cell_type": "code",
"id": "873bb4fc-1424-4816-b835-6c2b839d3de4",
"metadata": {},
"source": [
"df.count()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a1b895f-fe20-4520-a90d-b17df8e691e4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pyspark_huggingface",
"language": "python",
"name": "pyspark_huggingface"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions pyspark_huggingface/__init__.py
@@ -0,0 +1 @@
from pyspark_huggingface.huggingface import HuggingFaceDatasets
90 changes: 90 additions & 0 deletions pyspark_huggingface/huggingface.py
@@ -0,0 +1,90 @@
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructField, StructType, StringType


# TODO: Use `DefaultSource`
class HuggingFaceDatasets(DataSource):
"""
A DataSource for reading and writing HuggingFace Datasets in Spark.

This data source allows reading public datasets from the HuggingFace Hub directly into Spark
DataFrames. The schema is automatically inferred from the dataset features. The split can be
specified using the `split` option. The default split is `train`.

Name: `huggingface`

Notes:
-----
- The HuggingFace `datasets` library is required to use this data source. Make sure it is installed.
- If the schema is automatically inferred, it will use string type for all fields.
- Currently it can only be used with public datasets. Private or gated ones are not supported.

Examples
--------

Load a public dataset from the HuggingFace Hub.

>>> spark.read.format("huggingface").load("imdb").show()
+--------------------+-----+
| text|label|
+--------------------+-----+
|I rented I AM CUR...| 0|
|"I Am Curious: Ye...| 0|
|... | ...|
+--------------------+-----+

Load a specific split from a public dataset from the HuggingFace Hub.

>>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
+--------------------+-----+
| text|label|
+--------------------+-----+
|I love sci-fi and...| 0|
|Worth the enterta...| 0|
|... | ...|
+--------------------+-----+
"""

def __init__(self, options):
super().__init__(options)
if "path" not in options or not options["path"]:
raise Exception("You must specify a dataset name in`.load()`.")

@classmethod
def name(cls):
return "huggingface"

def schema(self):
from datasets import load_dataset_builder
dataset_name = self.options["path"]
ds_builder = load_dataset_builder(dataset_name)
Member:

Some datasets have configs/subsets that can be loaded like load_dataset_builder(dataset_name, subset_name)

Some functions that we can use:

  • get_dataset_config_names
  • get_dataset_default_config_name

and we can also validate the split name using

  • get_dataset_split_names

Contributor Author:

Nice, we can add an additional data source option for config/subset.
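A minimal sketch of how those helpers could be wired into the data source (not part of this PR; the `config` option name, the helper name, and the error messages are assumptions):

from datasets import (
    get_dataset_config_names,
    get_dataset_default_config_name,
    get_dataset_split_names,
    load_dataset_builder,
)

def resolve_builder_and_split(options: dict):
    # Hypothetical helper: pick a config/subset and validate the requested split.
    dataset_name = options["path"]
    config = options.get("config") or get_dataset_default_config_name(dataset_name)
    if config is not None and config not in get_dataset_config_names(dataset_name):
        raise Exception(f"Unknown config {config!r} for dataset {dataset_name!r}.")
    split = options.get("split", "train")
    if split not in get_dataset_split_names(dataset_name, config):
        raise Exception(f"Unknown split {split!r} for dataset {dataset_name!r}.")
    return load_dataset_builder(dataset_name, config), split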

features = ds_builder.info.features
if features is None:
raise Exception(
"Unable to automatically determine the schema using the dataset features. "
"Please specify the schema manually using `.schema()`."
)
schema = StructType()
for key, value in features.items():
# For simplicity, use string for all values.
schema.add(StructField(key, StringType(), True))
return schema
Member:

or simply this? :)

Suggested change:

-    schema = StructType()
-    for key, value in features.items():
-        # For simplicity, use string for all values.
-        schema.add(StructField(key, StringType(), True))
-    return schema
+    return from_arrow_schema(features.arrow_schema)

provided this is imported:

from pyspark.sql.pandas.types import from_arrow_schema

feel free to try in another PR if you prefer.

Contributor Author:

Good to know! So much easier to convert the schema now :)
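For reference, a sketch of what schema() could look like with that suggestion applied (not part of this PR as merged; assumes from_arrow_schema accepts the Arrow schema exposed by the dataset features):

def schema(self):
    from datasets import load_dataset_builder
    from pyspark.sql.pandas.types import from_arrow_schema

    ds_builder = load_dataset_builder(self.options["path"])
    features = ds_builder.info.features
    if features is None:
        raise Exception(
            "Unable to automatically determine the schema using the dataset features. "
            "Please specify the schema manually using `.schema()`."
        )
    # Convert the dataset's Arrow schema directly to a Spark StructType.
    return from_arrow_schema(features.arrow_schema)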


def reader(self, schema: StructType) -> "DataSourceReader":
return HuggingFaceDatasetsReader(schema, self.options)


class HuggingFaceDatasetsReader(DataSourceReader):
def __init__(self, schema: StructType, options: dict):
self.schema = schema
self.dataset_name = options["path"]
# TODO: validate the split value.
self.split = options.get("split", "train") # Default to the train split.

def read(self, partition):
from datasets import load_dataset
columns = [field.name for field in self.schema.fields]
iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True)
for example in iter_dataset:
# TODO: the next Spark 4.0.0 dev release will support yielding an iterator of pa.RecordBatch.
yield tuple([example.get(column) for column in columns])
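A possible shape for that TODO once record-batch reads are supported (a sketch only; the helper name and the batch size of 1000 are assumptions):

import pyarrow as pa
from datasets import load_dataset

def read_record_batches(dataset_name, split, schema, batch_size=1000):
    # Hypothetical helper: stream examples and group them into Arrow record batches.
    columns = [field.name for field in schema.fields]
    iter_dataset = load_dataset(dataset_name, split=split, streaming=True)
    buffer = []
    for example in iter_dataset:
        buffer.append({column: example.get(column) for column in columns})
        if len(buffer) == batch_size:
            yield pa.RecordBatch.from_pylist(buffer)
            buffer = []
    if buffer:
        yield pa.RecordBatch.from_pylist(buffer)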
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
datasets==3.1.0
pyspark[connect]==4.0.0.dev2
pytest
Empty file added tests/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions tests/test_huggingface.py
@@ -0,0 +1,15 @@
import pytest
from pyspark.sql import SparkSession
from pyspark_huggingface import HuggingFaceDatasets


@pytest.fixture
def spark():
spark = SparkSession.builder.getOrCreate()
yield spark


def test_basic_load(spark):
spark.dataSource.register(HuggingFaceDatasets)
df = spark.read.format("huggingface").load("rotten_tomatoes")
assert df.count() == 8530 # length of the training dataset
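
A possible follow-up test (not in this PR) exercising the split option; the column names come from the docstring example, and the exact row count is intentionally not asserted:

def test_load_test_split(spark):
    # Hypothetical test: read the "test" split instead of the default "train".
    spark.dataSource.register(HuggingFaceDatasets)
    df = spark.read.format("huggingface").option("split", "test").load("rotten_tomatoes")
    assert df.columns == ["text", "label"]
    assert df.count() > 0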