diff --git a/README.md b/README.md
index ac65599..8de0f61 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,92 @@
-# pyspark_huggingface
-PySpark custom data source for Hugging Face Datasets
+# Spark Data Source for Hugging Face Datasets
+
+A Spark Data Source for accessing [🤗 Hugging Face Datasets](https://huggingface.co/datasets):
+
+- Stream datasets from Hugging Face as Spark DataFrames
+- Select subsets and splits, apply projection and predicate filters
+- Save Spark DataFrames as Parquet files to Hugging Face
+- Fully distributed
+- Authentication via `huggingface-cli login` or tokens
+- Compatible with Spark 4 (with auto-import)
+- Backport for Spark 3.5, 3.4 and 3.3
+
+## Installation
+
+```
+pip install pyspark_huggingface
+```
+
+## Usage
+
+Load a dataset (here [stanfordnlp/imdb](https://huggingface.co/datasets/stanfordnlp/imdb)):
+
+```python
+df = spark.read.format("huggingface").load("stanfordnlp/imdb")
+```
+
+Save to Hugging Face:
+
+```python
+# Login with huggingface-cli login
+df.write.format("huggingface").save("username/my_dataset")
+# Or pass a token manually
+df.write.format("huggingface").option("token", "hf_xxx").save("username/my_dataset")
+```
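+
+The sink also supports an overwrite mode (a hedged sketch based on the sink's `overwrite` flag; exact mode handling can differ between Spark 4 and the Spark 3 backport):
+
+```python
+# Assumption: "username/my_dataset" already exists and its contents should be replaced
+df.write.format("huggingface").mode("overwrite").save("username/my_dataset")
+```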
+
+## Advanced
+
+Select a split:
+
+```python
+test_df = (
+ spark.read.format("huggingface")
+ .option("split", "test")
+ .load("stanfordnlp/imdb")
+)
+```
+
+Select a subset/config:
+
+```python
+test_df = (
+ spark.read.format("huggingface")
+ .option("config", "sample-10BT")
+ .load("HuggingFaceFW/fineweb-edu")
+)
+```
+
+Filter columns and rows (especially efficient for Parquet datasets):
+
+```python
+df = (
+ spark.read.format("huggingface")
+ .option("filters", '[("language_score", ">", 0.99)]')
+ .option("columns", '["text", "language_score"]')
+ .load("HuggingFaceFW/fineweb-edu")
+)
+```
+
+## Backport
+
+While the Data Source API was introduced in Spark 4, this package includes a backport for older versions.
+
+Importing `pyspark_huggingface` patches the PySpark reader and writer to add the "huggingface" data source. It is compatible with PySpark 3.5, 3.4 and 3.3:
+
+```python
+>>> import pyspark_huggingface
+huggingface datasource enabled for pyspark 3.x.x (backport from pyspark 4)
+```
+
+The import is only necessary on Spark 3.x to enable the backport.
+Spark 4 automatically imports `pyspark_huggingface` as soon as it is installed, and registers the "huggingface" data source.
diff --git a/demo.ipynb b/demo.ipynb
index acb1e49..eb031a0 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -1,36 +1,36 @@
{
"cells": [
{
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "8166a1c6bb7797bb",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-27T07:34:34.171635Z",
"start_time": "2024-11-27T07:34:34.161464Z"
}
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
- ],
- "id": "8166a1c6bb7797bb",
- "outputs": [],
- "execution_count": 1
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "d277e88f7ea91092",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-27T07:34:34.179778Z",
"start_time": "2024-11-27T07:34:34.174652Z"
}
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')"
- ],
- "id": "d277e88f7ea91092",
- "outputs": [],
- "execution_count": 2
+ ]
},
{
"cell_type": "markdown",
@@ -50,8 +50,10 @@
},
{
"cell_type": "code",
+ "execution_count": null,
"id": "620d3ecb-b9cb-480c-b300-69198cce7a9c",
"metadata": {},
+ "outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"\n",
@@ -60,9 +62,7 @@
" .config(\"spark.executor.memory\", \"20G\") \n",
" .getOrCreate()\n",
")"
- ],
- "outputs": [],
- "execution_count": null
+ ]
},
{
"cell_type": "markdown",
@@ -76,6 +76,7 @@
},
{
"cell_type": "code",
+ "execution_count": 2,
"id": "b8580bde-3f64-4c71-a087-8b3f71099aee",
"metadata": {
"ExecuteTime": {
@@ -83,14 +84,14 @@
"start_time": "2024-11-27T07:09:59.993537Z"
}
},
- "source": [
- "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")"
- ],
"outputs": [],
- "execution_count": 2
+ "source": [
+ "df = spark.read.format(\"huggingface\").load(\"cornell-movie-review-data/rotten_tomatoes\")"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 3,
"id": "3bbf61d1-4c2c-40e7-9790-2722637aac9d",
"metadata": {
"ExecuteTime": {
@@ -98,9 +99,6 @@
"start_time": "2024-11-27T07:10:11.695157Z"
}
},
- "source": [
- "df.printSchema()"
- ],
"outputs": [
{
"name": "stdout",
@@ -113,10 +111,13 @@
]
}
],
- "execution_count": 3
+ "source": [
+ "df.printSchema()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 4,
"id": "7f7b9a2b-8733-499a-af56-3c51196d060f",
"metadata": {
"ExecuteTime": {
@@ -124,10 +125,6 @@
"start_time": "2024-11-27T07:10:52.415881Z"
}
},
- "source": [
- "# Cache the dataframe to avoid re-downloading data. Note this should be used for small datasets.\n",
- "df.cache()"
- ],
"outputs": [
{
"data": {
@@ -140,10 +137,14 @@
"output_type": "execute_result"
}
],
- "execution_count": 4
+ "source": [
+ "# Cache the dataframe to avoid re-downloading data. Note this should be used for small datasets.\n",
+ "df.cache()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 5,
"id": "df121dba-2e1e-4206-b2bf-db156c298ee1",
"metadata": {
"ExecuteTime": {
@@ -151,10 +152,6 @@
"start_time": "2024-11-27T07:10:59.645232Z"
}
},
- "source": [
- "# Trigger the cache computation\n",
- "df.count()"
- ],
"outputs": [
{
"name": "stderr",
@@ -177,25 +174,26 @@
"output_type": "execute_result"
}
],
- "execution_count": 5
+ "source": [
+ "# Trigger the cache computation\n",
+ "df.count()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 6,
"id": "8866bdfb-0782-4430-8b1e-09c65e699f41",
"metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-11-27T07:11:35.994254Z",
+ "start_time": "2024-11-27T07:11:35.931924Z"
+ },
"editable": true,
"slideshow": {
"slide_type": ""
},
- "tags": [],
- "ExecuteTime": {
- "end_time": "2024-11-27T07:11:35.994254Z",
- "start_time": "2024-11-27T07:11:35.931924Z"
- }
+ "tags": []
},
- "source": [
- "df.head()"
- ],
"outputs": [
{
"data": {
@@ -208,10 +206,13 @@
"output_type": "execute_result"
}
],
- "execution_count": 6
+ "source": [
+ "df.head()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 7,
"id": "225bbbef-4164-424d-a701-c6c74494ef81",
"metadata": {
"ExecuteTime": {
@@ -219,10 +220,6 @@
"start_time": "2024-11-27T07:11:41.754692Z"
}
},
- "source": [
- "# Then you can operate on this dataframe\n",
- "df.filter(df.label == 0).count()"
- ],
"outputs": [
{
"data": {
@@ -235,16 +232,19 @@
"output_type": "execute_result"
}
],
- "execution_count": 7
+ "source": [
+ "# Then you can operate on this dataframe\n",
+ "df.filter(df.label == 0).count()"
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
+ "id": "bae9bc7f48526c36",
+ "metadata": {},
"source": [
"## Load a Dataset with a configuration/subset\n",
"Some datasets require explicitly specifying the config name. You can pass this as a data source option."
- ],
- "id": "bae9bc7f48526c36"
+ ]
},
{
"cell_type": "markdown",
@@ -257,6 +257,7 @@
},
{
"cell_type": "code",
+ "execution_count": 8,
"id": "a16e9270-eb02-4568-8739-db4dc715c274",
"metadata": {
"ExecuteTime": {
@@ -264,18 +265,18 @@
"start_time": "2024-11-27T07:11:54.300211Z"
}
},
+ "outputs": [],
"source": [
"test_df = (\n",
" spark.read.format(\"huggingface\")\n",
" .option(\"split\", \"test\")\n",
- " .load(\"rotten_tomatoes\")\n",
+ " .load(\"cornell-movie-review-data/rotten_tomatoes\")\n",
")"
- ],
- "outputs": [],
- "execution_count": 8
+ ]
},
{
"cell_type": "code",
+ "execution_count": 9,
"id": "3aec5719-c3a1-4d18-92c8-2b0c2f4bb939",
"metadata": {
"ExecuteTime": {
@@ -283,9 +284,6 @@
"start_time": "2024-11-27T07:12:02.817828Z"
}
},
- "source": [
- "test_df.cache()"
- ],
"outputs": [
{
"data": {
@@ -298,10 +296,13 @@
"output_type": "execute_result"
}
],
- "execution_count": 9
+ "source": [
+ "test_df.cache()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 10,
"id": "d605289d-361d-4a6c-9b70-f7ccdff3aa9d",
"metadata": {
"ExecuteTime": {
@@ -309,9 +310,6 @@
"start_time": "2024-11-27T07:12:02.891782Z"
}
},
- "source": [
- "test_df.count()"
- ],
"outputs": [
{
"name": "stderr",
@@ -331,10 +329,13 @@
"output_type": "execute_result"
}
],
- "execution_count": 10
+ "source": [
+ "test_df.count()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": 11,
"id": "df1ad003-1476-4557-811b-31c3888c0030",
"metadata": {
"ExecuteTime": {
@@ -342,9 +343,6 @@
"start_time": "2024-11-27T07:12:16.825661Z"
}
},
- "source": [
- "test_df.show(n=5)"
- ],
"outputs": [
{
"name": "stdout",
@@ -363,44 +361,45 @@
]
}
],
- "execution_count": 11
+ "source": [
+ "test_df.show(n=5)"
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
+ "id": "d8481e86aeb61aaf",
+ "metadata": {},
"source": [
"## Load a dataset with multiple shards\n",
"\n",
"This example is using the [amazon_popularity dataset](https://huggingface.co/datasets/fancyzhx/amazon_polarity) which has 4 shards (for train split)"
- ],
- "id": "d8481e86aeb61aaf"
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "43759f8c136366b8",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-27T07:12:25.864047Z",
"start_time": "2024-11-27T07:12:16.919834Z"
}
},
- "cell_type": "code",
- "source": "df = spark.read.format(\"huggingface\").load(\"amazon_polarity\")",
- "id": "43759f8c136366b8",
"outputs": [],
- "execution_count": 12
+ "source": [
+ "df = spark.read.format(\"huggingface\").load(\"fancyzhx/amazon_polarity\")"
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "acccc2c299be9205",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-27T07:13:04.733705Z",
"start_time": "2024-11-27T07:12:50.016560Z"
}
},
- "cell_type": "code",
- "source": [
- "# You can see there are 4 partitions, each correspond to one shard.\n",
- "df.rdd.getNumPartitions()"
- ],
- "id": "acccc2c299be9205",
"outputs": [
{
"data": {
@@ -413,42 +412,47 @@
"output_type": "execute_result"
}
],
- "execution_count": 13
+ "source": [
+ "# You can see there are 4 partitions, each correspond to one shard.\n",
+ "df.rdd.getNumPartitions()"
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
- "source": "",
- "id": "ae7ad16dfecf0e4c"
+ "id": "ae7ad16dfecf0e4c",
+ "metadata": {},
+ "source": []
},
{
- "metadata": {},
"cell_type": "markdown",
+ "id": "3587271d9a4f31ac",
+ "metadata": {},
"source": [
"## Load a dataset without streaming\n",
"\n",
"This is equivalent to `load_dataset(..., streaming=False)`"
- ],
- "id": "3587271d9a4f31ac"
+ ]
},
{
- "metadata": {},
"cell_type": "code",
- "source": "df = spark.read.format(\"huggingface\").option(\"streaming\", \"false\").load(\"imdb\")",
+ "execution_count": null,
"id": "5d319d7e93545788",
+ "metadata": {},
"outputs": [],
- "execution_count": null
+ "source": [
+ "df = spark.read.format(\"huggingface\").option(\"streaming\", \"false\").load(\"stanfordnlp/imdb\")"
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "bb4cdc5d8de427ea",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-27T07:15:03.325628Z",
"start_time": "2024-11-27T07:14:39.711977Z"
}
},
- "cell_type": "code",
- "source": "df.show(n=5)",
- "id": "bb4cdc5d8de427ea",
"outputs": [
{
"name": "stderr",
@@ -481,18 +485,20 @@
]
}
],
- "execution_count": 15
+ "source": [
+ "df.show(n=5)"
+ ]
},
{
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "f444cb40b7ae5044",
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-27T07:17:07.400787Z",
"start_time": "2024-11-27T07:16:43.681788Z"
}
},
- "cell_type": "code",
- "source": "df.filter(df.label == 1).show(n=5)",
- "id": "f444cb40b7ae5044",
"outputs": [
{
"name": "stderr",
@@ -525,15 +531,17 @@
]
}
],
- "execution_count": 16
+ "source": [
+ "df.filter(df.label == 1).show(n=5)"
+ ]
},
{
- "metadata": {},
"cell_type": "code",
- "outputs": [],
"execution_count": null,
- "source": "",
- "id": "a2da54a5cefe1fa3"
+ "id": "a2da54a5cefe1fa3",
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/pyspark_huggingface/compat/__init__.py b/pyspark_huggingface/compat/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pyspark_huggingface/compat/datasource.py b/pyspark_huggingface/compat/datasource.py
new file mode 100644
index 0000000..3a4abbc
--- /dev/null
+++ b/pyspark_huggingface/compat/datasource.py
@@ -0,0 +1,198 @@
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+
+import pyspark
+
+
+try:
+ from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader, DataSourceWriter, InputPartition, WriterCommitMessage
+except ImportError:
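+ # PySpark < 4 has no pyspark.sql.datasource module. Define minimal stand-ins for the
+ # classes this package subclasses, then monkey-patch DataFrameReader/DataFrameWriter
+ # below so that format("huggingface") keeps working on Spark 3.x.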
+ class DataSource:
+ def __init__(self, options):
+ self.options = options
+
+ class DataSourceArrowWriter:
+ ...
+
+ class DataSourceReader:
+ ...
+
+ class DataSourceWriter:
+ def __init__(self, options):
+ self.options = options
+
+ class InputPartition:
+ ...
+
+ class WriterCommitMessage:
+ ...
+
+
+ import logging
+ import os
+ import pickle
+ from functools import wraps
+
+ from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader, DataFrameWriter as _DataFrameWriter
+
+ if TYPE_CHECKING:
+ from pyarrow import RecordBatch
+ from pyspark.sql.dataframe import DataFrame
+ from pyspark.sql.readwriter import PathOrPaths
+ from pyspark.sql.types import StructType
+ from pyspark.sql._typing import OptionalPrimitiveType
+
+
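+ # Serializes arbitrary Python objects (input partitions, commit messages) into a
+ # single binary column so they can travel through a Spark DataFrame.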
+ class _ArrowPickler:
+
+ def __init__(self, key: str):
+ from pyspark.sql.types import StructType, StructField, BinaryType
+
+ self.key = key
+ self.schema = StructType([StructField(self.key, BinaryType(), True)])
+
+ def dumps(self, obj):
+ return {self.key: pickle.dumps(obj)}
+
+ def loads(self, obj):
+ return pickle.loads(obj[self.key])
+
+ # Reader
+
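+ # Each incoming row carries one pickled input partition; unpickle it and stream the
+ # corresponding Arrow record batches from the Hugging Face reader.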
+ def _read_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_reader) -> Iterator["RecordBatch"]:
+ for batch in batches:
+ for record in batch.to_pylist():
+ partition = arrow_pickler.loads(record)
+ yield from hf_reader.read(partition)
+
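+ # Spark 3 forwards format and options straight to the JVM reader, so mirror them on
+ # the Python object to detect "huggingface" loads and rebuild the option dict later.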
+ _orig_reader_format = _DataFrameReader.format
+
+ @wraps(_orig_reader_format)
+ def _new_format(self: _DataFrameReader, source: str) -> _DataFrameReader:
+ self._format = source
+ return _orig_reader_format(self, source)
+
+ _DataFrameReader.format = _new_format
+
+ _orig_reader_option = _DataFrameReader.option
+
+ @wraps(_orig_reader_option)
+ def _new_option(self: _DataFrameReader, key, value) -> _DataFrameReader:
+ if not hasattr(self, "_options"):
+ self._options = {}
+ self._options[key] = value
+ return _orig_reader_option(self, key, value)
+
+ _DataFrameReader.option = _new_option
+
+ _orig_reader_options = _DataFrameReader.options
+
+ @wraps(_orig_reader_options)
+ def _new_options(self: _DataFrameReader, **options) -> _DataFrameReader:
+ if not hasattr(self, "_options"):
+ self._options = {}
+ self._options.update(options)
+ return _orig_reader_options(self, **options)
+
+ _DataFrameReader.options = _new_options
+
+ _orig_reader_load = _DataFrameReader.load
+
+ @wraps(_orig_reader_load)
+ def _new_load(
+ self: _DataFrameReader,
+ path: Optional["PathOrPaths"] = None,
+ format: Optional[str] = None,
+ schema: Optional[Union["StructType", str]] = None,
+ **options: "OptionalPrimitiveType",
+ ) -> "DataFrame":
+ if (format or getattr(self, "_format", None)) == "huggingface":
+ from functools import partial
+ from pyspark.sql import SparkSession
+ from pyspark_huggingface.huggingface import HuggingFaceDatasets
+
+ source = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_source()
+ schema = schema or source.schema()
+ hf_reader = source.reader(schema)
+ partitions = hf_reader.partitions()
+ arrow_pickler = _ArrowPickler("partition")
+ spark = self._spark if isinstance(self._spark, SparkSession) else self._spark.sparkSession # _spark is SQLContext for older versions
+ rdd = spark.sparkContext.parallelize([arrow_pickler.dumps(partition) for partition in partitions], len(partitions))
+ df = spark.createDataFrame(rdd)
+ return df.mapInArrow(partial(_read_in_arrow, arrow_pickler=arrow_pickler, hf_reader=hf_reader), schema)
+
+ return _orig_reader_load(self, path=path, format=format, schema=schema, **options)
+
+ _DataFrameReader.load = _new_load
+
+ # Writer
+
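+ # Run the Hugging Face writer over each stream of Arrow batches and return its commit
+ # message as a single pickled row, to be collected and committed on the driver.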
+ def _write_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_writer) -> Iterator["RecordBatch"]:
+ from pyarrow import RecordBatch
+
+ commit_message = hf_writer.write(batches)
+ yield RecordBatch.from_pylist([arrow_pickler.dumps(commit_message)])
+
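+ # Same bookkeeping as the reader: record format and options on the Python writer object.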
+ _orig_writer_format = _DataFrameWriter.format
+
+ @wraps(_orig_writer_format)
+ def _new_format(self: _DataFrameWriter, source: str) -> _DataFrameWriter:
+ self._format = source
+ return _orig_writer_format(self, source)
+
+ _DataFrameWriter.format = _new_format
+
+ _orig_writer_option = _DataFrameWriter.option
+
+ @wraps(_orig_writer_option)
+ def _new_option(self: _DataFrameWriter, key, value) -> _DataFrameWriter:
+ if not hasattr(self, "_options"):
+ self._options = {}
+ self._options[key] = value
+ return _orig_writer_option(self, key, value)
+
+ _DataFrameWriter.option = _new_option
+
+ _orig_writer_options = _DataFrameWriter.options
+
+ @wraps(_orig_writer_options)
+ def _new_options(self: _DataFrameWriter, **options) -> _DataFrameWriter:
+ if not hasattr(self, "_options"):
+ self._options = {}
+ self._options.update(options)
+ return _orig_writer_options(self, **options)
+
+ _DataFrameWriter.options = _new_options
+
+ _orig_writer_save = _DataFrameWriter.save
+
+ @wraps(_orig_writer_save)
+ def _new_save(
+ self: _DataFrameWriter,
+ path: Optional["PathOrPaths"] = None,
+ format: Optional[str] = None,
+ mode: Optional[str] = None,
+ partitionBy: Optional[Union[str, List[str]]] = None,
+ **options: "OptionalPrimitiveType",
+ ) -> "DataFrame":
+ if (format or getattr(self, "_format", None)) == "huggingface":
+ from functools import partial
+ from pyspark_huggingface.huggingface import HuggingFaceDatasets
+
+ # Resolve the save mode first so it does not leak into the sink options
+ mode = mode or options.pop("mode", None)
+ sink = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_sink()
+ schema = self._df.schema
+ hf_writer = sink.writer(schema, overwrite=(mode == "overwrite"))
+ arrow_pickler = _ArrowPickler("commit_message")
+ commit_messages = self._df.mapInArrow(partial(_write_in_arrow, arrow_pickler=arrow_pickler, hf_writer=hf_writer), arrow_pickler.schema).collect()
+ commit_messages = [arrow_pickler.loads(commit_message) for commit_message in commit_messages]
+ hf_writer.commit(commit_messages)
+ return
+
+ return _orig_writer_save(self, path=path, format=format, mode=mode, partitionBy=partitionBy, **options)
+
+ _DataFrameWriter.save = _new_save
+
+ # Log only in driver
+
+ if not os.environ.get("SPARK_ENV_LOADED"):
+ logging.getLogger(__name__).warning(f"huggingface datasource enabled for pyspark {pyspark.__version__} (backport from pyspark 4)")
diff --git a/pyspark_huggingface/huggingface.py b/pyspark_huggingface/huggingface.py
index 784bb19..9b8084c 100644
--- a/pyspark_huggingface/huggingface.py
+++ b/pyspark_huggingface/huggingface.py
@@ -1,9 +1,9 @@
from typing import TYPE_CHECKING, Optional
-from pyspark.sql.datasource import DataSource
+from pyspark_huggingface.compat.datasource import DataSource
if TYPE_CHECKING:
- from pyspark.sql.datasource import DataSourceWriter, DataSourceReader
+ from pyspark_huggingface.compat.datasource import DataSourceWriter, DataSourceReader
from pyspark.sql.types import StructType
from pyspark_huggingface.huggingface_sink import HuggingFaceSink
diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py
index c572be3..edef431 100644
--- a/pyspark_huggingface/huggingface_sink.py
+++ b/pyspark_huggingface/huggingface_sink.py
@@ -3,12 +3,12 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
-from pyspark.sql.datasource import (
+from pyspark.sql.types import StructType
+from pyspark_huggingface.compat.datasource import (
DataSource,
DataSourceArrowWriter,
WriterCommitMessage,
)
-from pyspark.sql.types import StructType
if TYPE_CHECKING:
from huggingface_hub import (
@@ -66,12 +66,14 @@ def __init__(self, options):
if "path" not in options or not options["path"]:
raise Exception("You must specify a dataset name.")
+ from huggingface_hub import get_token
+
kwargs = dict(self.options)
- self.token = kwargs.pop("token")
self.repo_id = kwargs.pop("path")
self.path_in_repo = kwargs.pop("path_in_repo", None)
self.split = kwargs.pop("split", None)
self.revision = kwargs.pop("revision", None)
+ self.token = kwargs.pop("token", None) or get_token()
self.endpoint = kwargs.pop("endpoint", None)
for arg in kwargs:
if kwargs[arg].lower() == "true":
@@ -89,7 +91,7 @@ def __init__(self, options):
def name(cls):
return "huggingfacesink"
- def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter:
+ def writer(self, schema: StructType, overwrite: bool) -> "HuggingFaceDatasetsWriter":
return HuggingFaceDatasetsWriter(
repo_id=self.repo_id,
path_in_repo=self.path_in_repo,
diff --git a/pyspark_huggingface/huggingface_source.py b/pyspark_huggingface/huggingface_source.py
index 9ec9907..9cd0f5a 100644
--- a/pyspark_huggingface/huggingface_source.py
+++ b/pyspark_huggingface/huggingface_source.py
@@ -2,9 +2,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Sequence
-from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.pandas.types import from_arrow_schema
from pyspark.sql.types import StructType
+from pyspark_huggingface.compat.datasource import DataSource, DataSourceReader, InputPartition
+
if TYPE_CHECKING:
from datasets import DatasetBuilder, IterableDataset
@@ -33,7 +34,7 @@ class HuggingFaceSource(DataSource):
Load a public dataset from the HuggingFace Hub.
- >>> df = spark.read.format("huggingface").load("imdb")
+ >>> df = spark.read.format("huggingface").load("stanfordnlp/imdb")
DataFrame[text: string, label: bigint]
>>> df.show()
@@ -47,7 +48,7 @@ class HuggingFaceSource(DataSource):
Load a specific split from a public dataset from the HuggingFace Hub.
- >>> spark.read.format("huggingface").option("split", "test").load("imdb").show()
+ >>> spark.read.format("huggingface").option("split", "test").load("stanfordnlp/imdb").show()
+--------------------+-----+
| text|label|
+--------------------+-----+
@@ -81,13 +82,16 @@ def __init__(self, options):
if "path" not in options or not options["path"]:
raise Exception("You must specify a dataset name.")
+ from huggingface_hub import get_token
+
kwargs = dict(self.options)
self.dataset_name = kwargs.pop("path")
self.config_name = kwargs.pop("config", None)
self.split = kwargs.pop("split", self.DEFAULT_SPLIT)
self.revision = kwargs.pop("revision", None)
self.streaming = kwargs.pop("streaming", "true").lower() == "true"
- self.token = kwargs.pop("token", None)
+ self.token = kwargs.pop("token", None) or get_token()
+ self.endpoint = kwargs.pop("endpoint", None)
for arg in kwargs:
if kwargs[arg].lower() == "true":
kwargs[arg] = True
@@ -115,7 +119,7 @@ def __init__(self, options):
def _get_api(self):
from huggingface_hub import HfApi
- return HfApi(token=self.token, library_name="pyspark_huggingface")
+ return HfApi(token=self.token, endpoint=self.endpoint, library_name="pyspark_huggingface")
@classmethod
def name(cls):
@@ -124,7 +128,7 @@ def name(cls):
def schema(self):
return from_arrow_schema(self.streaming_dataset.features.arrow_schema)
- def reader(self, schema: StructType) -> "DataSourceReader":
+ def reader(self, schema: StructType) -> "HuggingFaceDatasetsReader":
return HuggingFaceDatasetsReader(
schema,
builder=self.builder,
@@ -148,7 +152,7 @@ def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, st
self.streaming_dataset = streaming_dataset
# Get and validate the split name
- def partitions(self) -> Sequence[InputPartition]:
+ def partitions(self) -> Sequence[Shard]:
if self.streaming_dataset:
return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)]
else:
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index bd26d90..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-datasets==3.1.0
-pyspark[connect]==4.0.0.dev2
-pytest