diff --git a/README.md b/README.md
index ac65599..8de0f61 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,92 @@
-# pyspark_huggingface
-PySpark custom data source for Hugging Face Datasets
+<!-- Banner: Hugging Face x Spark -->
+<!-- Badges: GitHub release, Number of datasets -->
+
+# Spark Data Source for Hugging Face Datasets
+
+A Spark Data Source for accessing [🤗 Hugging Face Datasets](https://huggingface.co/datasets):
+
+- Stream datasets from Hugging Face as Spark DataFrames
+- Select subsets and splits, apply projection and predicate filters
+- Save Spark DataFrames as Parquet files to Hugging Face
+- Fully distributed
+- Authentication via `huggingface-cli login` or tokens
+- Compatible with Spark 4 (with auto-import)
+- Backport for Spark 3.5, 3.4 and 3.3
+
+## Installation
+
+```
+pip install pyspark_huggingface
+```
+
+## Usage
+
+Load a dataset (here [stanfordnlp/imdb](https://huggingface.co/datasets/stanfordnlp/imdb)):
+
+```python
+df = spark.read.format("huggingface").load("stanfordnlp/imdb")
+```
+
+Save to Hugging Face:
+
+```python
+# Login with huggingface-cli login
+df.write.format("huggingface").save("username/my_dataset")
+# Or pass a token manually
+df.write.format("huggingface").option("token", "hf_xxx").save("username/my_dataset")
+```
+
+## Advanced
+
+Select a split:
+
+```python
+test_df = (
+    spark.read.format("huggingface")
+    .option("split", "test")
+    .load("stanfordnlp/imdb")
+)
+```
+
+Select a subset/config:
+
+```python
+test_df = (
+    spark.read.format("huggingface")
+    .option("config", "sample-10BT")
+    .load("HuggingFaceFW/fineweb-edu")
+)
+```
+
+Filter columns and rows (especially efficient for Parquet datasets):
+
+```python
+df = (
+    spark.read.format("huggingface")
+    .option("filters", '[("language_score", ">", 0.99)]')
+    .option("columns", '["text", "language_score"]')
+    .load("HuggingFaceFW/fineweb-edu")
+)
+```
+
+## Backport
+
+While the Data Source API was introduced in Spark 4, this package includes a backport for older versions.
+
+Importing `pyspark_huggingface` patches the PySpark reader and writer to add the "huggingface" data source. It is compatible with PySpark 3.5, 3.4 and 3.3:
+
+```python
+>>> import pyspark_huggingface
+huggingface datasource enabled for pyspark 3.x.x (backport from pyspark 4)
+```
+
+The import is only necessary on Spark 3.x to enable the backport.
+Spark 4 automatically imports `pyspark_huggingface` as soon as it is installed, and registers the "huggingface" data source.
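+
+For illustration, here is a minimal end-to-end sketch on Spark 3.x (the app name and chosen split are just examples):
+
+```python
+import pyspark_huggingface  # on Spark 3.x this import registers the "huggingface" data source
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.appName("hf-backport-demo").getOrCreate()
+
+# Stream the train split as a DataFrame, then write the result back to the Hub
+train_df = (
+    spark.read.format("huggingface")
+    .option("split", "train")
+    .load("stanfordnlp/imdb")
+)
+train_df.write.format("huggingface").option("token", "hf_xxx").save("username/my_dataset")
+```
+
+The same code runs on Spark 4, where the import is optional since the data source is auto-registered.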
diff --git a/demo.ipynb b/demo.ipynb index acb1e49..eb031a0 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -1,36 +1,36 @@ { "cells": [ { + "cell_type": "code", + "execution_count": 1, + "id": "8166a1c6bb7797bb", "metadata": { "ExecuteTime": { "end_time": "2024-11-27T07:34:34.171635Z", "start_time": "2024-11-27T07:34:34.161464Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" - ], - "id": "8166a1c6bb7797bb", - "outputs": [], - "execution_count": 1 + ] }, { + "cell_type": "code", + "execution_count": 2, + "id": "d277e88f7ea91092", "metadata": { "ExecuteTime": { "end_time": "2024-11-27T07:34:34.179778Z", "start_time": "2024-11-27T07:34:34.174652Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore')" - ], - "id": "d277e88f7ea91092", - "outputs": [], - "execution_count": 2 + ] }, { "cell_type": "markdown", @@ -50,8 +50,10 @@ }, { "cell_type": "code", + "execution_count": null, "id": "620d3ecb-b9cb-480c-b300-69198cce7a9c", "metadata": {}, + "outputs": [], "source": [ "from pyspark.sql import SparkSession\n", "\n", @@ -60,9 +62,7 @@ " .config(\"spark.executor.memory\", \"20G\") \n", " .getOrCreate()\n", ")" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", @@ -76,6 +76,7 @@ }, { "cell_type": "code", + "execution_count": 2, "id": "b8580bde-3f64-4c71-a087-8b3f71099aee", "metadata": { "ExecuteTime": { @@ -83,14 +84,14 @@ "start_time": "2024-11-27T07:09:59.993537Z" } }, - "source": [ - "df = spark.read.format(\"huggingface\").load(\"rotten_tomatoes\")" - ], "outputs": [], - "execution_count": 2 + "source": [ + "df = spark.read.format(\"huggingface\").load(\"cornell-movie-review-data/rotten_tomatoes\")" + ] }, { "cell_type": "code", + "execution_count": 3, "id": "3bbf61d1-4c2c-40e7-9790-2722637aac9d", "metadata": { "ExecuteTime": { @@ -98,9 +99,6 @@ "start_time": "2024-11-27T07:10:11.695157Z" } }, - "source": [ - "df.printSchema()" - ], "outputs": [ { "name": "stdout", @@ -113,10 +111,13 @@ ] } ], - "execution_count": 3 + "source": [ + "df.printSchema()" + ] }, { "cell_type": "code", + "execution_count": 4, "id": "7f7b9a2b-8733-499a-af56-3c51196d060f", "metadata": { "ExecuteTime": { @@ -124,10 +125,6 @@ "start_time": "2024-11-27T07:10:52.415881Z" } }, - "source": [ - "# Cache the dataframe to avoid re-downloading data. Note this should be used for small datasets.\n", - "df.cache()" - ], "outputs": [ { "data": { @@ -140,10 +137,14 @@ "output_type": "execute_result" } ], - "execution_count": 4 + "source": [ + "# Cache the dataframe to avoid re-downloading data. 
Note this should be used for small datasets.\n", + "df.cache()" + ] }, { "cell_type": "code", + "execution_count": 5, "id": "df121dba-2e1e-4206-b2bf-db156c298ee1", "metadata": { "ExecuteTime": { @@ -151,10 +152,6 @@ "start_time": "2024-11-27T07:10:59.645232Z" } }, - "source": [ - "# Trigger the cache computation\n", - "df.count()" - ], "outputs": [ { "name": "stderr", @@ -177,25 +174,26 @@ "output_type": "execute_result" } ], - "execution_count": 5 + "source": [ + "# Trigger the cache computation\n", + "df.count()" + ] }, { "cell_type": "code", + "execution_count": 6, "id": "8866bdfb-0782-4430-8b1e-09c65e699f41", "metadata": { + "ExecuteTime": { + "end_time": "2024-11-27T07:11:35.994254Z", + "start_time": "2024-11-27T07:11:35.931924Z" + }, "editable": true, "slideshow": { "slide_type": "" }, - "tags": [], - "ExecuteTime": { - "end_time": "2024-11-27T07:11:35.994254Z", - "start_time": "2024-11-27T07:11:35.931924Z" - } + "tags": [] }, - "source": [ - "df.head()" - ], "outputs": [ { "data": { @@ -208,10 +206,13 @@ "output_type": "execute_result" } ], - "execution_count": 6 + "source": [ + "df.head()" + ] }, { "cell_type": "code", + "execution_count": 7, "id": "225bbbef-4164-424d-a701-c6c74494ef81", "metadata": { "ExecuteTime": { @@ -219,10 +220,6 @@ "start_time": "2024-11-27T07:11:41.754692Z" } }, - "source": [ - "# Then you can operate on this dataframe\n", - "df.filter(df.label == 0).count()" - ], "outputs": [ { "data": { @@ -235,16 +232,19 @@ "output_type": "execute_result" } ], - "execution_count": 7 + "source": [ + "# Then you can operate on this dataframe\n", + "df.filter(df.label == 0).count()" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "bae9bc7f48526c36", + "metadata": {}, "source": [ "## Load a Dataset with a configuration/subset\n", "Some datasets require explicitly specifying the config name. You can pass this as a data source option." 
- ], - "id": "bae9bc7f48526c36" + ] }, { "cell_type": "markdown", @@ -257,6 +257,7 @@ }, { "cell_type": "code", + "execution_count": 8, "id": "a16e9270-eb02-4568-8739-db4dc715c274", "metadata": { "ExecuteTime": { @@ -264,18 +265,18 @@ "start_time": "2024-11-27T07:11:54.300211Z" } }, + "outputs": [], "source": [ "test_df = (\n", " spark.read.format(\"huggingface\")\n", " .option(\"split\", \"test\")\n", - " .load(\"rotten_tomatoes\")\n", + " .load(\"cornell-movie-review-data/rotten_tomatoes\")\n", ")" - ], - "outputs": [], - "execution_count": 8 + ] }, { "cell_type": "code", + "execution_count": 9, "id": "3aec5719-c3a1-4d18-92c8-2b0c2f4bb939", "metadata": { "ExecuteTime": { @@ -283,9 +284,6 @@ "start_time": "2024-11-27T07:12:02.817828Z" } }, - "source": [ - "test_df.cache()" - ], "outputs": [ { "data": { @@ -298,10 +296,13 @@ "output_type": "execute_result" } ], - "execution_count": 9 + "source": [ + "test_df.cache()" + ] }, { "cell_type": "code", + "execution_count": 10, "id": "d605289d-361d-4a6c-9b70-f7ccdff3aa9d", "metadata": { "ExecuteTime": { @@ -309,9 +310,6 @@ "start_time": "2024-11-27T07:12:02.891782Z" } }, - "source": [ - "test_df.count()" - ], "outputs": [ { "name": "stderr", @@ -331,10 +329,13 @@ "output_type": "execute_result" } ], - "execution_count": 10 + "source": [ + "test_df.count()" + ] }, { "cell_type": "code", + "execution_count": 11, "id": "df1ad003-1476-4557-811b-31c3888c0030", "metadata": { "ExecuteTime": { @@ -342,9 +343,6 @@ "start_time": "2024-11-27T07:12:16.825661Z" } }, - "source": [ - "test_df.show(n=5)" - ], "outputs": [ { "name": "stdout", @@ -363,44 +361,45 @@ ] } ], - "execution_count": 11 + "source": [ + "test_df.show(n=5)" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "d8481e86aeb61aaf", + "metadata": {}, "source": [ "## Load a dataset with multiple shards\n", "\n", "This example is using the [amazon_popularity dataset](https://huggingface.co/datasets/fancyzhx/amazon_polarity) which has 4 shards (for train split)" - ], - "id": "d8481e86aeb61aaf" + ] }, { + "cell_type": "code", + "execution_count": 12, + "id": "43759f8c136366b8", "metadata": { "ExecuteTime": { "end_time": "2024-11-27T07:12:25.864047Z", "start_time": "2024-11-27T07:12:16.919834Z" } }, - "cell_type": "code", - "source": "df = spark.read.format(\"huggingface\").load(\"amazon_polarity\")", - "id": "43759f8c136366b8", "outputs": [], - "execution_count": 12 + "source": [ + "df = spark.read.format(\"huggingface\").load(\"fancyzhx/amazon_polarity\")" + ] }, { + "cell_type": "code", + "execution_count": 13, + "id": "acccc2c299be9205", "metadata": { "ExecuteTime": { "end_time": "2024-11-27T07:13:04.733705Z", "start_time": "2024-11-27T07:12:50.016560Z" } }, - "cell_type": "code", - "source": [ - "# You can see there are 4 partitions, each correspond to one shard.\n", - "df.rdd.getNumPartitions()" - ], - "id": "acccc2c299be9205", "outputs": [ { "data": { @@ -413,42 +412,47 @@ "output_type": "execute_result" } ], - "execution_count": 13 + "source": [ + "# You can see there are 4 partitions, each correspond to one shard.\n", + "df.rdd.getNumPartitions()" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "", - "id": "ae7ad16dfecf0e4c" + "id": "ae7ad16dfecf0e4c", + "metadata": {}, + "source": [] }, { - "metadata": {}, "cell_type": "markdown", + "id": "3587271d9a4f31ac", + "metadata": {}, "source": [ "## Load a dataset without streaming\n", "\n", "This is equivalent to `load_dataset(..., streaming=False)`" - ], - "id": "3587271d9a4f31ac" + ] }, { - "metadata": {}, "cell_type": 
"code", - "source": "df = spark.read.format(\"huggingface\").option(\"streaming\", \"false\").load(\"imdb\")", + "execution_count": null, "id": "5d319d7e93545788", + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "df = spark.read.format(\"huggingface\").option(\"streaming\", \"false\").load(\"stanfordnlp/imdb\")" + ] }, { + "cell_type": "code", + "execution_count": 15, + "id": "bb4cdc5d8de427ea", "metadata": { "ExecuteTime": { "end_time": "2024-11-27T07:15:03.325628Z", "start_time": "2024-11-27T07:14:39.711977Z" } }, - "cell_type": "code", - "source": "df.show(n=5)", - "id": "bb4cdc5d8de427ea", "outputs": [ { "name": "stderr", @@ -481,18 +485,20 @@ ] } ], - "execution_count": 15 + "source": [ + "df.show(n=5)" + ] }, { + "cell_type": "code", + "execution_count": 16, + "id": "f444cb40b7ae5044", "metadata": { "ExecuteTime": { "end_time": "2024-11-27T07:17:07.400787Z", "start_time": "2024-11-27T07:16:43.681788Z" } }, - "cell_type": "code", - "source": "df.filter(df.label == 1).show(n=5)", - "id": "f444cb40b7ae5044", "outputs": [ { "name": "stderr", @@ -525,15 +531,17 @@ ] } ], - "execution_count": 16 + "source": [ + "df.filter(df.label == 1).show(n=5)" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "", - "id": "a2da54a5cefe1fa3" + "id": "a2da54a5cefe1fa3", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/pyspark_huggingface/compat/__init__.py b/pyspark_huggingface/compat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyspark_huggingface/compat/datasource.py b/pyspark_huggingface/compat/datasource.py new file mode 100644 index 0000000..3a4abbc --- /dev/null +++ b/pyspark_huggingface/compat/datasource.py @@ -0,0 +1,198 @@ +from typing import TYPE_CHECKING, Iterator, List, Optional, Union + +import pyspark + + +try: + from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader, DataSourceWriter, InputPartition, WriterCommitMessage +except ImportError: + class DataSource: + def __init__(self, options): + self.options = options + + class DataSourceArrowWriter: + ... + + class DataSourceReader: + ... + + class DataSourceWriter: + def __init__(self, options): + self.options = options + + class InputPartition: + ... + + class WriterCommitMessage: + ... 
+ + + import logging + import os + import pickle + from functools import wraps + + from pyspark.sql.readwriter import DataFrameReader as _DataFrameReader, DataFrameWriter as _DataFrameWriter + + if TYPE_CHECKING: + from pyarrow import RecordBatch + from pyspark.sql.dataframe import DataFrame + from pyspark.sql.readwriter import PathOrPaths + from pyspark.sql.types import StructType + from pyspark.sql._typing import OptionalPrimitiveType + + + class _ArrowPickler: + + def __init__(self, key: str): + from pyspark.sql.types import StructType, StructField, BinaryType + + self.key = key + self.schema = StructType([StructField(self.key, BinaryType(), True)]) + + def dumps(self, obj): + return {self.key: pickle.dumps(obj)} + + def loads(self, obj): + return pickle.loads(obj[self.key]) + + # Reader + + def _read_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_reader) -> Iterator["RecordBatch"]: + for batch in batches: + for record in batch.to_pylist(): + partition = arrow_pickler.loads(record) + yield from hf_reader.read(partition) + + _orig_reader_format = _DataFrameReader.format + + @wraps(_orig_reader_format) + def _new_format(self: _DataFrameReader, source: str) -> _DataFrameReader: + self._format = source + return _orig_reader_format(self, source) + + _DataFrameReader.format = _new_format + + _orig_reader_option = _DataFrameReader.option + + @wraps(_orig_reader_option) + def _new_option(self: _DataFrameReader, key, value) -> _DataFrameReader: + if not hasattr(self, "_options"): + self._options = {} + self._options[key] = value + return _orig_reader_option(self, key, value) + + _DataFrameReader.option = _new_option + + _orig_reader_options = _DataFrameReader.options + + @wraps(_orig_reader_options) + def _new_options(self: _DataFrameReader, **options) -> _DataFrameReader: + if not hasattr(self, "_options"): + self._options = {} + self._options.update(options) + return _orig_reader_options(self, **options) + + _DataFrameReader.options = _new_options + + _orig_reader_load = _DataFrameReader.load + + @wraps(_orig_reader_load) + def _new_load( + self: _DataFrameReader, + path: Optional["PathOrPaths"] = None, + format: Optional[str] = None, + schema: Optional[Union["StructType", str]] = None, + **options: "OptionalPrimitiveType", + ) -> "DataFrame": + if (format or getattr(self, "_format", None)) == "huggingface": + from functools import partial + from pyspark.sql import SparkSession + from pyspark_huggingface.huggingface import HuggingFaceDatasets + + source = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_source() + schema = schema or source.schema() + hf_reader = source.reader(schema) + partitions = hf_reader.partitions() + arrow_pickler = _ArrowPickler("partition") + spark = self._spark if isinstance(self._spark, SparkSession) else self._spark.sparkSession # _spark is SQLContext for older versions + rdd = spark.sparkContext.parallelize([arrow_pickler.dumps(partition) for partition in partitions], len(partitions)) + df = spark.createDataFrame(rdd) + return df.mapInArrow(partial(_read_in_arrow, arrow_pickler=arrow_pickler, hf_reader=hf_reader), schema) + + return _orig_reader_load(self, path=path, format=format, schema=schema, **options) + + _DataFrameReader.load = _new_load + + # Writer + + def _write_in_arrow(batches: Iterator["RecordBatch"], arrow_pickler, hf_writer) -> Iterator["RecordBatch"]: + from pyarrow import RecordBatch + + commit_message = hf_writer.write(batches) + yield 
RecordBatch.from_pylist([arrow_pickler.dumps(commit_message)]) + + _orig_writer_format = _DataFrameWriter.format + + @wraps(_orig_writer_format) + def _new_format(self: _DataFrameWriter, source: str) -> _DataFrameWriter: + self._format = source + return _orig_writer_format(self, source) + + _DataFrameWriter.format = _new_format + + _orig_writer_option = _DataFrameWriter.option + + @wraps(_orig_writer_option) + def _new_option(self: _DataFrameWriter, key, value) -> _DataFrameWriter: + if not hasattr(self, "_options"): + self._options = {} + self._options[key] = value + return _orig_writer_option(self, key, value) + + _DataFrameWriter.option = _new_option + + _orig_writer_options = _DataFrameWriter.options + + @wraps(_orig_writer_options) + def _new_options(self: _DataFrameWriter, **options) -> _DataFrameWriter: + if not hasattr(self, "_options"): + self._options = {} + self._options.update(options) + return _orig_writer_options(self, **options) + + _DataFrameWriter.options = _new_options + + _orig_writer_save = _DataFrameWriter.save + + @wraps(_orig_writer_save) + def _new_save( + self: _DataFrameWriter, + path: Optional["PathOrPaths"] = None, + format: Optional[str] = None, + mode: Optional[str] = None, + partitionBy: Optional[Union[str, List[str]]] = None, + **options: "OptionalPrimitiveType", + ) -> None: + if (format or getattr(self, "_format", None)) == "huggingface": + from functools import partial + from pyspark_huggingface.huggingface import HuggingFaceDatasets + + sink = HuggingFaceDatasets(options={**getattr(self, "_options", {}), **options, "path": path}).get_sink() + schema = self._df.schema + mode = mode or options.pop("mode", None) + hf_writer = sink.writer(schema, overwrite=(mode == "overwrite")) + arrow_pickler = _ArrowPickler("commit_message") + commit_messages = self._df.mapInArrow(partial(_write_in_arrow, arrow_pickler=arrow_pickler, hf_writer=hf_writer), arrow_pickler.schema).collect() + commit_messages = [arrow_pickler.loads(commit_message) for commit_message in commit_messages] + hf_writer.commit(commit_messages) + return + + return _orig_writer_save(self, path=path, format=format, mode=mode, partitionBy=partitionBy, **options) + + _DataFrameWriter.save = _new_save + + # Log only in driver + + if not os.environ.get("SPARK_ENV_LOADED"): + logging.getLogger(__name__).warning(f"huggingface datasource enabled for pyspark {pyspark.__version__} (backport from pyspark 4)") diff --git a/pyspark_huggingface/huggingface.py b/pyspark_huggingface/huggingface.py index 784bb19..9b8084c 100644 --- a/pyspark_huggingface/huggingface.py +++ b/pyspark_huggingface/huggingface.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING, Optional -from pyspark.sql.datasource import DataSource +from pyspark_huggingface.compat.datasource import DataSource if TYPE_CHECKING: - from pyspark.sql.datasource import DataSourceWriter, DataSourceReader + from pyspark_huggingface.compat.datasource import DataSourceWriter, DataSourceReader from pyspark.sql.types import StructType from pyspark_huggingface.huggingface_sink import HuggingFaceSink diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index c572be3..edef431 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -3,12 +3,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Iterator, List, Optional, Union -from pyspark.sql.datasource import ( +from pyspark.sql.types import StructType +from pyspark_huggingface.compat.datasource import ( 
DataSource, DataSourceArrowWriter, WriterCommitMessage, ) -from pyspark.sql.types import StructType if TYPE_CHECKING: from huggingface_hub import ( @@ -66,12 +66,14 @@ def __init__(self, options): if "path" not in options or not options["path"]: raise Exception("You must specify a dataset name.") + from huggingface_hub import get_token + kwargs = dict(self.options) - self.token = kwargs.pop("token") self.repo_id = kwargs.pop("path") self.path_in_repo = kwargs.pop("path_in_repo", None) self.split = kwargs.pop("split", None) self.revision = kwargs.pop("revision", None) + self.token = kwargs.pop("token", None) or get_token() self.endpoint = kwargs.pop("endpoint", None) for arg in kwargs: if kwargs[arg].lower() == "true": @@ -89,7 +91,7 @@ def __init__(self, options): def name(cls): return "huggingfacesink" - def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter: + def writer(self, schema: StructType, overwrite: bool) -> "HuggingFaceDatasetsWriter": return HuggingFaceDatasetsWriter( repo_id=self.repo_id, path_in_repo=self.path_in_repo, diff --git a/pyspark_huggingface/huggingface_source.py b/pyspark_huggingface/huggingface_source.py index 9ec9907..9cd0f5a 100644 --- a/pyspark_huggingface/huggingface_source.py +++ b/pyspark_huggingface/huggingface_source.py @@ -2,9 +2,10 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Sequence -from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition from pyspark.sql.pandas.types import from_arrow_schema from pyspark.sql.types import StructType +from pyspark_huggingface.compat.datasource import DataSource, DataSourceReader, InputPartition + if TYPE_CHECKING: from datasets import DatasetBuilder, IterableDataset @@ -33,7 +34,7 @@ class HuggingFaceSource(DataSource): Load a public dataset from the HuggingFace Hub. - >>> df = spark.read.format("huggingface").load("imdb") + >>> df = spark.read.format("huggingface").load("stanfordnlp/imdb") DataFrame[text: string, label: bigint] >>> df.show() @@ -47,7 +48,7 @@ class HuggingFaceSource(DataSource): Load a specific split from a public dataset from the HuggingFace Hub. 
- >>> spark.read.format("huggingface").option("split", "test").load("imdb").show() + >>> spark.read.format("huggingface").option("split", "test").load("stanfordnlp/imdb").show() +--------------------+-----+ | text|label| +--------------------+-----+ @@ -81,13 +82,16 @@ def __init__(self, options): if "path" not in options or not options["path"]: raise Exception("You must specify a dataset name.") + from huggingface_hub import get_token + kwargs = dict(self.options) self.dataset_name = kwargs.pop("path") self.config_name = kwargs.pop("config", None) self.split = kwargs.pop("split", self.DEFAULT_SPLIT) self.revision = kwargs.pop("revision", None) self.streaming = kwargs.pop("streaming", "true").lower() == "true" - self.token = kwargs.pop("token", None) + self.token = kwargs.pop("token", None) or get_token() + self.endpoint = kwargs.pop("endpoint", None) for arg in kwargs: if kwargs[arg].lower() == "true": kwargs[arg] = True @@ -115,7 +119,7 @@ def __init__(self, options): def _get_api(self): from huggingface_hub import HfApi - return HfApi(token=self.token, library_name="pyspark_huggingface") + return HfApi(token=self.token, endpoint=self.endpoint, library_name="pyspark_huggingface") @classmethod def name(cls): @@ -124,7 +128,7 @@ def name(cls): def schema(self): return from_arrow_schema(self.streaming_dataset.features.arrow_schema) - def reader(self, schema: StructType) -> "DataSourceReader": + def reader(self, schema: StructType) -> "HuggingFaceDatasetsReader": return HuggingFaceDatasetsReader( schema, builder=self.builder, @@ -148,7 +152,7 @@ def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, st self.streaming_dataset = streaming_dataset # Get and validate the split name - def partitions(self) -> Sequence[InputPartition]: + def partitions(self) -> Sequence[Shard]: if self.streaming_dataset: return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)] else: diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index bd26d90..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -datasets==3.1.0 -pyspark[connect]==4.0.0.dev2 -pytest