-
Notifications
You must be signed in to change notification settings - Fork 6
Add basic HuggingFace Data Source Implementation #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,167 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 11, | ||
| "id": "125a1871-6cab-4dc4-9fd5-4e5dbd63ada6", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import warnings\n", | ||
| "warnings.filterwarnings('ignore')" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 2, | ||
| "id": "38dc7e9e-35fd-4604-9be3-1a1a8749fbcb", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from pyspark_huggingface import HuggingFaceDatasets" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 3, | ||
| "id": "620d3ecb-b9cb-480c-b300-69198cce7a9c", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from pyspark.sql import SparkSession" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 12, | ||
| "id": "9255ffcb-0b61-43dc-b57a-2b8af01a8432", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "spark = SparkSession.builder.getOrCreate()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "id": "7c4501a8-26f4-4f52-9dc8-a70393d567b4", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "spark.dataSource.register(HuggingFaceDatasets)" | ||
| ], | ||
| "outputs": [], | ||
| "execution_count": null | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 14, | ||
| "id": "b8580bde-3f64-4c71-a087-8b3f71099aee", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 15, | ||
| "id": "8866bdfb-0782-4430-8b1e-09c65e699f41", | ||
| "metadata": { | ||
| "editable": true, | ||
| "slideshow": { | ||
| "slide_type": "" | ||
| }, | ||
| "tags": [] | ||
| }, | ||
| "outputs": [ | ||
| { | ||
| "name": "stderr", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "[Stage 5:> (0 + 1) / 1]" | ||
| ] | ||
| }, | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "+--------------------+-----+\n", | ||
| "| text|label|\n", | ||
| "+--------------------+-----+\n", | ||
| "|the rock is desti...| 1|\n", | ||
| "|the gorgeously el...| 1|\n", | ||
| "|effective but too...| 1|\n", | ||
| "|if you sometimes ...| 1|\n", | ||
| "|emerges as someth...| 1|\n", | ||
| "|the film provides...| 1|\n", | ||
| "|offers that rare ...| 1|\n", | ||
| "|perhaps no pictur...| 1|\n", | ||
| "|steers turns in a...| 1|\n", | ||
| "|take care of my c...| 1|\n", | ||
| "|this is a film we...| 1|\n", | ||
| "|what really surpr...| 1|\n", | ||
| "|( wendigo is ) wh...| 1|\n", | ||
| "|one of the greate...| 1|\n", | ||
| "|ultimately , it p...| 1|\n", | ||
| "|an utterly compel...| 1|\n", | ||
| "|illuminating if o...| 1|\n", | ||
| "|a masterpiece fou...| 1|\n", | ||
| "|the movie's ripe ...| 1|\n", | ||
| "|offers a breath o...| 1|\n", | ||
| "+--------------------+-----+\n", | ||
| "only showing top 20 rows\n", | ||
| "\n" | ||
| ] | ||
| }, | ||
| { | ||
| "name": "stderr", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| " " | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "df.show()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "id": "873bb4fc-1424-4816-b835-6c2b839d3de4", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "df.count()" | ||
| ], | ||
| "outputs": [], | ||
| "execution_count": null | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "4a1b895f-fe20-4520-a90d-b17df8e691e4", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "pyspark_huggingface", | ||
| "language": "python", | ||
| "name": "pyspark_huggingface" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.11.10" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from pyspark_huggingface.huggingface import HuggingFaceDatasets |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,90 @@ | ||||||||||||||
| from pyspark.sql.datasource import DataSource, DataSourceReader | ||||||||||||||
| from pyspark.sql.types import StructField, StructType, StringType | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| # TODO: Use `DefaultSource` | ||||||||||||||
| class HuggingFaceDatasets(DataSource): | ||||||||||||||
| """ | ||||||||||||||
| A DataSource for reading and writing HuggingFace Datasets in Spark. | ||||||||||||||
|
|
||||||||||||||
| This data source allows reading public datasets from the HuggingFace Hub directly into Spark | ||||||||||||||
| DataFrames. The schema is automatically inferred from the dataset features. The split can be | ||||||||||||||
| specified using the `split` option. The default split is `train`. | ||||||||||||||
|
|
||||||||||||||
| Name: `huggingface` | ||||||||||||||
|
|
||||||||||||||
| Notes: | ||||||||||||||
| ----- | ||||||||||||||
| - The HuggingFace `datasets` library is required to use this data source. Make sure it is installed. | ||||||||||||||
| - If the schema is automatically inferred, it will use string type for all fields. | ||||||||||||||
| - Currently it can only be used with public datasets. Private or gated ones are not supported. | ||||||||||||||
|
|
||||||||||||||
| Examples | ||||||||||||||
| -------- | ||||||||||||||
|
|
||||||||||||||
| Load a public dataset from the HuggingFace Hub. | ||||||||||||||
|
|
||||||||||||||
| >>> spark.read.format("huggingface").load("imdb").show() | ||||||||||||||
| +--------------------+-----+ | ||||||||||||||
| | text|label| | ||||||||||||||
| +--------------------+-----+ | ||||||||||||||
| |I rented I AM CUR...| 0| | ||||||||||||||
| |"I Am Curious: Ye...| 0| | ||||||||||||||
| |... | ...| | ||||||||||||||
| +--------------------+-----+ | ||||||||||||||
|
|
||||||||||||||
| Load a specific split from a public dataset from the HuggingFace Hub. | ||||||||||||||
|
|
||||||||||||||
| >>> spark.read.format("huggingface").option("split", "test").load("imdb").show() | ||||||||||||||
| +--------------------+-----+ | ||||||||||||||
| | text|label| | ||||||||||||||
| +--------------------+-----+ | ||||||||||||||
| |I love sci-fi and...| 0| | ||||||||||||||
| |Worth the enterta...| 0| | ||||||||||||||
| |... | ...| | ||||||||||||||
| +--------------------+-----+ | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, options): | ||||||||||||||
| super().__init__(options) | ||||||||||||||
| if "path" not in options or not options["path"]: | ||||||||||||||
| raise Exception("You must specify a dataset name in`.load()`.") | ||||||||||||||
|
|
||||||||||||||
| @classmethod | ||||||||||||||
| def name(cls): | ||||||||||||||
| return "huggingface" | ||||||||||||||
|
|
||||||||||||||
| def schema(self): | ||||||||||||||
| from datasets import load_dataset_builder | ||||||||||||||
| dataset_name = self.options["path"] | ||||||||||||||
| ds_builder = load_dataset_builder(dataset_name) | ||||||||||||||
| features = ds_builder.info.features | ||||||||||||||
| if features is None: | ||||||||||||||
| raise Exception( | ||||||||||||||
| "Unable to automatically determine the schema using the dataset features. " | ||||||||||||||
| "Please specify the schema manually using `.schema()`." | ||||||||||||||
| ) | ||||||||||||||
| schema = StructType() | ||||||||||||||
| for key, value in features.items(): | ||||||||||||||
| # For simplicity, use string for all values. | ||||||||||||||
| schema.add(StructField(key, StringType(), True)) | ||||||||||||||
| return schema | ||||||||||||||
|
||||||||||||||
| schema = StructType() | |
| for key, value in features.items(): | |
| # For simplicity, use string for all values. | |
| schema.add(StructField(key, StringType(), True)) | |
| return schema | |
| return from_arrow_schema(features.arrow_schema) |
provided this is imported:
from pyspark.sql.pandas.types import from_arrow_schemafeel free to try in another PR if you prefer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to know! So much easier to convert the schema now :)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| datasets==3.1.0 | ||
| pyspark[connect]==4.0.0.dev2 | ||
| pytest |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| import pytest | ||
| from pyspark.sql import SparkSession | ||
| from pyspark_huggingface import HuggingFaceDatasets | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def spark(): | ||
| spark = SparkSession.builder.getOrCreate() | ||
| yield spark | ||
|
|
||
|
|
||
| def test_basic_load(spark): | ||
| spark.dataSource.register(HuggingFaceDatasets) | ||
| df = spark.read.format("huggingface").load("rotten_tomatoes") | ||
| assert df.count() == 8530 # length of the training dataset |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some datasets have configs/subsets that can ba loaded like
load_dataset_builder(dataset_name, subset_name)Some functions that we can use:
and we can also validate the split name using
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice we can add an additional data source option for config/subset.