From d818b03387aa65165406eff516fa534eb732c82c Mon Sep 17 00:00:00 2001
From: Allison Wang
Date: Mon, 25 Nov 2024 19:18:03 +0800
Subject: [PATCH 1/2] init

---
 .gitignore                         |   2 +-
 demo.ipynb                         | 167 +++++++++++++++++++++++++++++
 pyspark_huggingface/__init__.py    |   1 +
 pyspark_huggingface/huggingface.py |  90 ++++++++++++++++
 requirements.txt                   |   3 +
 tests/__init__.py                  |   0
 tests/test_huggingface.py          |  15 +++
 7 files changed, 277 insertions(+), 1 deletion(-)
 create mode 100644 demo.ipynb
 create mode 100644 pyspark_huggingface/__init__.py
 create mode 100644 pyspark_huggingface/huggingface.py
 create mode 100644 requirements.txt
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_huggingface.py

diff --git a/.gitignore b/.gitignore
index 82f9275..7b6caf3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
diff --git a/demo.ipynb b/demo.ipynb
new file mode 100644
index 0000000..b1a5bf7
--- /dev/null
+++ b/demo.ipynb
@@ -0,0 +1,167 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "125a1871-6cab-4dc4-9fd5-4e5dbd63ada6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "38dc7e9e-35fd-4604-9be3-1a1a8749fbcb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pyspark_huggingface import HuggingFaceDatasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "620d3ecb-b9cb-480c-b300-69198cce7a9c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pyspark.sql import SparkSession"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "9255ffcb-0b61-43dc-b57a-2b8af01a8432",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "spark = SparkSession.builder.getOrCreate()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "7c4501a8-26f4-4f52-9dc8-a70393d567b4",
+   "metadata": {},
+   "source": [
+    "spark.dataSource.register(HuggingFaceDatasets)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "8866bdfb-0782-4430-8b1e-09c65e699f41",
+   "metadata": {
+    "editable": true,
+    "slideshow": {
+     "slide_type": ""
+    },
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[Stage 5:> (0 + 1) / 1]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "+--------------------+-----+\n",
+      "|                text|label|\n",
+      "+--------------------+-----+\n",
+      "|the rock is desti...|    1|\n",
+      "|the gorgeously el...|    1|\n",
+      "|effective but too...|    1|\n",
+      "|if you sometimes ...|    1|\n",
+      "|emerges as someth...|    1|\n",
+      "|the film provides...|    1|\n",
+      "|offers that rare ...|    1|\n",
+      "|perhaps no pictur...|    1|\n",
+      "|steers turns in a...|    1|\n",
+      "|take care of my c...|    1|\n",
+      "|this is a film we...|    1|\n",
+      "|what really surpr...|    1|\n",
+      "|( wendigo is ) wh...|    1|\n",
+      "|one of the greate...|    1|\n",
+      "|ultimately , it p...|    1|\n",
+      "|an utterly compel...|    1|\n",
+      "|illuminating if o...|    1|\n",
+      "|a masterpiece fou...|    1|\n",
+      "|the movie's ripe ...|    1|\n",
+      "|offers a breath o...|    1|\n",
+      "+--------------------+-----+\n",
+      "only showing top 20 rows\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      " "
+     ]
+    }
+   ],
+   "source": [
+    "df.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "873bb4fc-1424-4816-b835-6c2b839d3de4",
+   "metadata": {},
+   "source": [
+    "df.count()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4a1b895f-fe20-4520-a90d-b17df8e691e4",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pyspark_huggingface",
+   "language": "python",
+   "name": "pyspark_huggingface"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyspark_huggingface/__init__.py b/pyspark_huggingface/__init__.py
new file mode 100644
index 0000000..713350f
--- /dev/null
+++ b/pyspark_huggingface/__init__.py
@@ -0,0 +1 @@
+from pyspark_huggingface.huggingface import HuggingFaceDatasets
diff --git a/pyspark_huggingface/huggingface.py b/pyspark_huggingface/huggingface.py
new file mode 100644
index 0000000..2746e75
--- /dev/null
+++ b/pyspark_huggingface/huggingface.py
@@ -0,0 +1,90 @@
+from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.types import StructField, StructType, StringType
+
+
+# TODO: Use `DefaultSource`
+class HuggingFaceDatasets(DataSource):
+    """
+    A DataSource for reading and writing HuggingFace Datasets in Spark.
+
+    This data source allows reading public datasets from the HuggingFace Hub directly into Spark
+    DataFrames. The schema is automatically inferred from the dataset features. The split can be
+    specified using the `split` option. The default split is `train`.
+
+    Name: `huggingface`
+
+    Notes:
+    -----
+    - The HuggingFace `datasets` library is required to use this data source. Make sure it is installed.
+    - If the schema is automatically inferred, it will use string type for all fields.
+    - Currently it can only be used with public datasets. Private or gated ones are not supported.
+
+    Examples
+    --------
+
+    Load a public dataset from the HuggingFace Hub.
+
+    >>> spark.read.format("huggingface").load("imdb").show()
+    +--------------------+-----+
+    |                text|label|
+    +--------------------+-----+
+    |I rented I AM CUR...|    0|
+    |"I Am Curious: Ye...|    0|
+    |...                 |  ...|
+    +--------------------+-----+
+
+    Load a specific split from a public dataset from the HuggingFace Hub.
+
+    >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
+    +--------------------+-----+
+    |                text|label|
+    +--------------------+-----+
+    |I love sci-fi and...|    0|
+    |Worth the enterta...|    0|
+    |...                 |  ...|
+    +--------------------+-----+
+    """
+
+    def __init__(self, options):
+        super().__init__(options)
+        if "path" not in options or not options["path"]:
+            raise Exception("You must specify a dataset name in`.load()`.")
+
+    @classmethod
+    def name(cls):
+        return "huggingface"
+
+    def schema(self):
+        from datasets import load_dataset_builder
+        dataset_name = self.options["path"]
+        ds_builder = load_dataset_builder(dataset_name)
+        features = ds_builder.info.features
+        if features is None:
+            raise Exception(
+                "Unable to automatically determine the schema using the dataset features. "
+                "Please specify the schema manually using `.schema()`."
+            )
+        schema = StructType()
+        for key, value in features.items():
+            # For simplicity, use string for all values.
+            schema.add(StructField(key, StringType(), True))
+        return schema
+
+    def reader(self, schema: StructType) -> "DataSourceReader":
+        return HuggingFaceDatasetsReader(schema, self.options)
+
+
+class HuggingFaceDatasetsReader(DataSourceReader):
+    def __init__(self, schema: StructType, options: dict):
+        self.schema = schema
+        self.dataset_name = options["path"]
+        # TODO: validate the split value.
+        self.split = options.get("split", "train")  # Default using train split.
+
+    def read(self, partition):
+        from datasets import load_dataset
+        columns = [field.name for field in self.schema.fields]
+        iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True)
+        for example in iter_dataset:
+            # TODO: next spark 4.0.0 dev release will include the feature to yield as an iterator of pa.RecordBatch
+            yield tuple([example.get(column) for column in columns])
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bd26d90
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+datasets==3.1.0
+pyspark[connect]==4.0.0.dev2
+pytest
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_huggingface.py b/tests/test_huggingface.py
new file mode 100644
index 0000000..8c6b9b8
--- /dev/null
+++ b/tests/test_huggingface.py
@@ -0,0 +1,15 @@
+import pytest
+from pyspark.sql import SparkSession
+from pyspark_huggingface import HuggingFaceDatasets
+
+
+@pytest.fixture
+def spark():
+    spark = SparkSession.builder.getOrCreate()
+    yield spark
+
+
+def test_basic_load(spark):
+    spark.dataSource.register(HuggingFaceDatasets)
+    df = spark.read.format("huggingface").load("rotten_tomatoes")
+    assert df.count() == 8530  # length of the training dataset

From e3412fc72298acea64db855ba27e2d8e16f60bdf Mon Sep 17 00:00:00 2001
From: Allison Wang
Date: Tue, 26 Nov 2024 17:05:21 +0800
Subject: [PATCH 2/2] address comments

---
 demo.ipynb                         | 260 +++++++++++++++++++++--------
 pyspark_huggingface/__init__.py    |   2 +-
 pyspark_huggingface/huggingface.py |  33 ++--
 3 files changed, 215 insertions(+), 80 deletions(-)

diff --git a/demo.ipynb b/demo.ipynb
index b1a5bf7..48234e3 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -1,69 +1,113 @@
 {
  "cells": [
   {
-   "cell_type": "code",
-   "execution_count": 11,
-   "id": "125a1871-6cab-4dc4-9fd5-4e5dbd63ada6",
+   "cell_type": "markdown",
+   "id": "19b1960e-9e0a-401f-be15-d343902eaa21",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "import warnings\n",
-    "warnings.filterwarnings('ignore')"
+    "# Spark HuggingFace Connector Demo"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 2,
-   "id": "38dc7e9e-35fd-4604-9be3-1a1a8749fbcb",
+   "cell_type": "markdown",
+   "id": "c9a7bf1d-c208-4873-9e06-5db981f8eeaa",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "from pyspark_huggingface import HuggingFaceDatasets"
+    "## Create a Spark Session"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
    "id": "620d3ecb-b9cb-480c-b300-69198cce7a9c",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "from pyspark.sql import SparkSession"
-   ]
+    "from pyspark.sql import SparkSession\n",
+    "\n",
+    "spark = SparkSession.builder.getOrCreate()"
+   ],
+   "outputs": [],
+   "execution_count": null
   },
   {
-   "cell_type": "code",
-   "execution_count": 12,
-   "id": "9255ffcb-0b61-43dc-b57a-2b8af01a8432",
+   "cell_type": "markdown",
+   "id": "6f876028-2af5-4e63-8e9d-59afc0959267",
    "metadata": {},
-   "outputs": [],
    "source": [
-    "spark = SparkSession.builder.getOrCreate()"
+    "## Load a dataset as a Spark DataFrame"
    ]
   },
   {
    "cell_type": "code",
-   "id": "7c4501a8-26f4-4f52-9dc8-a70393d567b4",
-   "metadata": {},
-   "source": [
-    "spark.dataSource.register(HuggingFaceDatasets)"
-   ],
+   "execution_count": 2,
+   "id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-11-26T08:54:32.132099Z",
+     "start_time": "2024-11-26T08:54:28.903653Z"
+    }
+   },
    "outputs": [],
-   "execution_count": null
+   "source": [
+    "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
+   ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
-   "id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
+   "execution_count": 4,
+   "id": "3bbf61d1-4c2c-40e7-9790-2722637aac9d",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "root\n",
+      " |-- text: string (nullable = true)\n",
+      " |-- label: long (nullable = true)\n",
+      "\n"
+     ]
+    }
+   ],
    "source": [
-    "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
+    "df.printSchema()"
    ]
   },
+  {
+   "cell_type": "code",
+   "id": "7f7b9a2b-8733-499a-af56-3c51196d060f",
+   "metadata": {},
+   "source": [
+    "# Cache the dataframe to avoid re-downloading data\n",
+    "df.cache()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "df121dba-2e1e-4206-b2bf-db156c298ee1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "8530"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Trigger the cache computation\n",
+    "df.count()"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 13,
    "id": "8866bdfb-0782-4430-8b1e-09c65e699f41",
    "metadata": {
     "editable": true,
@@ -74,70 +118,156 @@
    },
    "outputs": [
    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "[Stage 5:> (0 + 1) / 1]"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "+--------------------+-----+\n",
-      "|                text|label|\n",
-      "+--------------------+-----+\n",
-      "|the rock is desti...|    1|\n",
-      "|the gorgeously el...|    1|\n",
-      "|effective but too...|    1|\n",
-      "|if you sometimes ...|    1|\n",
-      "|emerges as someth...|    1|\n",
-      "|the film provides...|    1|\n",
-      "|offers that rare ...|    1|\n",
-      "|perhaps no pictur...|    1|\n",
-      "|steers turns in a...|    1|\n",
-      "|take care of my c...|    1|\n",
-      "|this is a film we...|    1|\n",
-      "|what really surpr...|    1|\n",
-      "|( wendigo is ) wh...|    1|\n",
-      "|one of the greate...|    1|\n",
-      "|ultimately , it p...|    1|\n",
-      "|an utterly compel...|    1|\n",
-      "|illuminating if o...|    1|\n",
-      "|a masterpiece fou...|    1|\n",
-      "|the movie's ripe ...|    1|\n",
-      "|offers a breath o...|    1|\n",
-      "+--------------------+-----+\n",
-      "only showing top 20 rows\n",
-      "\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      " "
-     ]
+     "data": {
+      "text/plain": [
+       "Row(text='the rock is destined to be the 21st century\\'s new \" conan \" and that he\\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .', label=1)"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
-    "df.show()"
+    "df.head()"
    ]
   },
   {
    "cell_type": "code",
-   "id": "873bb4fc-1424-4816-b835-6c2b839d3de4",
+   "id": "0d9d3112-d19b-4fa8-a6fc-ba40816d1d11",
    "metadata": {},
    "source": [
-    "df.count()"
+    "df.show(n=5)"
    ],
    "outputs": [],
    "execution_count": null
   },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "225bbbef-4164-424d-a701-c6c74494ef81",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "4265"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Then you can operate on this dataframe\n",
+    "df.filter(df.label == 0).count()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3932f1fd-a324-4f15-86e1-bbe1064d707a",
+   "metadata": {},
+   "source": [
+    "## Load a different split\n",
+    "You can specify the `split` data source option:"
+   ]
+  },
+  {
+   "cell_type": "code",
"code", + "execution_count": 14, + "id": "a16e9270-eb02-4568-8739-db4dc715c274", + "metadata": {}, + "outputs": [], + "source": [ + "test_df = (\n", + " spark.read.format(\"huggingface\")\n", + " .option(\"split\", \"test\")\n", + " .load(\"rotten_tomatoes\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3aec5719-c3a1-4d18-92c8-2b0c2f4bb939", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataFrame[text: string, label: bigint]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df.cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "d605289d-361d-4a6c-9b70-f7ccdff3aa9d", + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[Stage 5:> (0 + 1) / 1]" + " " ] }, + { + "data": { + "text/plain": [ + "1066" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_df.count()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "df1ad003-1476-4557-811b-31c3888c0030", + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", @@ -87,57 +249,25 @@ "+--------------------+-----+\n", "| text|label|\n", "+--------------------+-----+\n", - "|the rock is desti...| 1|\n", - "|the gorgeously el...| 1|\n", - "|effective but too...| 1|\n", - "|if you sometimes ...| 1|\n", - "|emerges as someth...| 1|\n", - "|the film provides...| 1|\n", - "|offers that rare ...| 1|\n", - "|perhaps no pictur...| 1|\n", - "|steers turns in a...| 1|\n", - "|take care of my c...| 1|\n", - "|this is a film we...| 1|\n", - "|what really surpr...| 1|\n", - "|( wendigo is ) wh...| 1|\n", - "|one of the greate...| 1|\n", - "|ultimately , it p...| 1|\n", - "|an utterly compel...| 1|\n", - "|illuminating if o...| 1|\n", - "|a masterpiece fou...| 1|\n", - "|the movie's ripe ...| 1|\n", - "|offers a breath o...| 1|\n", + "|lovingly photogra...| 1|\n", + "|consistently clev...| 1|\n", + "|it's like a \" big...| 1|\n", + "|the story gives a...| 1|\n", + "|red dragon \" neve...| 1|\n", "+--------------------+-----+\n", - "only showing top 20 rows\n", + "only showing top 5 rows\n", "\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " " - ] } ], "source": [ - "df.show()" + "test_df.show(n=5)" ] }, - { - "cell_type": "code", - "id": "873bb4fc-1424-4816-b835-6c2b839d3de4", - "metadata": {}, - "source": [ - "df.count()" - ], - "outputs": [], - "execution_count": null - }, { "cell_type": "code", "execution_count": null, - "id": "4a1b895f-fe20-4520-a90d-b17df8e691e4", + "id": "a7f14b91-059e-4894-83d2-4ed74e0adaf9", "metadata": {}, "outputs": [], "source": [] diff --git a/pyspark_huggingface/__init__.py b/pyspark_huggingface/__init__.py index 713350f..ac09d17 100644 --- a/pyspark_huggingface/__init__.py +++ b/pyspark_huggingface/__init__.py @@ -1 +1 @@ -from pyspark_huggingface.huggingface import HuggingFaceDatasets +from pyspark_huggingface.huggingface import HuggingFaceDatasets as DefaultSource diff --git a/pyspark_huggingface/huggingface.py b/pyspark_huggingface/huggingface.py index 2746e75..42680c8 100644 --- a/pyspark_huggingface/huggingface.py +++ b/pyspark_huggingface/huggingface.py @@ -1,8 +1,7 @@ from pyspark.sql.datasource import DataSource, DataSourceReader -from pyspark.sql.types import StructField, StructType, StringType +from pyspark.sql.pandas.types import from_arrow_schema +from pyspark.sql.types import StructType - -# TODO: Use 
`DefaultSource` class HuggingFaceDatasets(DataSource): """ A DataSource for reading and writing HuggingFace Datasets in Spark. @@ -24,7 +23,10 @@ class HuggingFaceDatasets(DataSource): Load a public dataset from the HuggingFace Hub. - >>> spark.read.format("huggingface").load("imdb").show() + >>> df = spark.read.format("huggingface").load("imdb") + DataFrame[text: string, label: bigint] + + >>> df.show() +--------------------+-----+ | text|label| +--------------------+-----+ @@ -48,7 +50,7 @@ class HuggingFaceDatasets(DataSource): def __init__(self, options): super().__init__(options) if "path" not in options or not options["path"]: - raise Exception("You must specify a dataset name in`.load()`.") + raise Exception("You must specify a dataset name.") @classmethod def name(cls): @@ -64,27 +66,30 @@ def schema(self): "Unable to automatically determine the schema using the dataset features. " "Please specify the schema manually using `.schema()`." ) - schema = StructType() - for key, value in features.items(): - # For simplicity, use string for all values. - schema.add(StructField(key, StringType(), True)) - return schema + return from_arrow_schema(features.arrow_schema) def reader(self, schema: StructType) -> "DataSourceReader": return HuggingFaceDatasetsReader(schema, self.options) class HuggingFaceDatasetsReader(DataSourceReader): + DEFAULT_SPLIT: str = "train" + def __init__(self, schema: StructType, options: dict): self.schema = schema self.dataset_name = options["path"] - # TODO: validate the split value. - self.split = options.get("split", "train") # Default using train split. + # Get and validate the split name + self.split = options.get("split", self.DEFAULT_SPLIT) + from datasets import get_dataset_split_names + valid_splits = get_dataset_split_names(self.dataset_name) + if self.split not in valid_splits: + raise Exception(f"Split {self.split} is invalid. Valid options are {valid_splits}") def read(self, partition): from datasets import load_dataset columns = [field.name for field in self.schema.fields] + # TODO: add config iter_dataset = load_dataset(self.dataset_name, split=self.split, streaming=True) - for example in iter_dataset: + for data in iter_dataset: # TODO: next spark 4.0.0 dev release will include the feature to yield as an iterator of pa.RecordBatch - yield tuple([example.get(column) for column in columns]) + yield tuple([data.get(column) for column in columns])
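
The series is self-contained, so it can be smoke-tested with nothing beyond what the two patches add. The snippet below is a minimal, hypothetical check (it is not part of either patch): it only uses calls that appear above, namely spark.dataSource.register, the "split" read option, and the DefaultSource alias introduced in PATCH 2/2, and it assumes pyspark[connect]==4.0.0.dev2 and datasets==3.1.0 from requirements.txt are installed. The expected counts come from the demo notebook.

    # Hypothetical smoke test for this series; not included in either patch.
    from pyspark.sql import SparkSession

    from pyspark_huggingface import DefaultSource  # alias from PATCH 2/2

    spark = SparkSession.builder.getOrCreate()

    # name() returns "huggingface", so registering enables the short format name.
    spark.dataSource.register(DefaultSource)

    # After PATCH 2/2 the schema comes from the dataset features via
    # from_arrow_schema, so label is a bigint rather than a string.
    df = spark.read.format("huggingface").load("rotten_tomatoes")
    df.printSchema()
    assert df.count() == 8530  # train split size shown in demo.ipynb

    test_df = (
        spark.read.format("huggingface")
        .option("split", "test")
        .load("rotten_tomatoes")
    )
    assert test_df.count() == 1066  # test split size shown in demo.ipynb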
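
The reader's streaming behavior is also easy to prototype outside Spark. This sketch mirrors what HuggingFaceDatasetsReader does after PATCH 2/2: validate the split with get_dataset_split_names, then iterate with streaming=True so the dataset is never fully downloaded. Here peek_rows is a hypothetical helper for illustration, not part of the patches.

    from datasets import get_dataset_split_names, load_dataset


    def peek_rows(dataset_name: str, split: str = "train", limit: int = 3) -> list:
        # Same validation the reader performs in its constructor.
        valid_splits = get_dataset_split_names(dataset_name)
        if split not in valid_splits:
            raise Exception(f"Split {split} is invalid. Valid options are {valid_splits}")
        # streaming=True yields examples lazily, exactly how read() consumes them.
        iter_dataset = load_dataset(dataset_name, split=split, streaming=True)
        rows = []
        for i, data in enumerate(iter_dataset):
            if i >= limit:
                break
            rows.append(data)
        return rows


    print(peek_rows("rotten_tomatoes", split="validation"))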