From 5d6706694f9f5f7412a6a99bf249f8516b4a30ab Mon Sep 17 00:00:00 2001 From: polinaeterna Date: Mon, 13 Mar 2023 14:39:17 +0100 Subject: [PATCH] pass split to download_and_prepare --- src/datasets/builder.py | 11 ++++++++--- src/datasets/load.py | 1 + 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 9f0d947c78a..aaea13c474e 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -638,6 +638,7 @@ def download_and_prepare( max_shard_size: Optional[Union[int, str]] = None, num_proc: Optional[int] = None, storage_options: Optional[dict] = None, + split: Optional[str] = None, **download_and_prepare_kwargs, ): """Downloads and prepares dataset for reading. @@ -879,6 +880,9 @@ def incomplete_dir(dirname): logger.warning("HF google storage unreachable. Downloading and preparing it from source") if not downloaded_from_gcs: prepare_split_kwargs = {"file_format": file_format} + # if "split" in download_and_prepare_kwargs: + if split: + prepare_split_kwargs = {**prepare_split_kwargs, "split": split} if max_shard_size is not None: prepare_split_kwargs["max_shard_size"] = max_shard_size if num_proc is not None: @@ -955,8 +959,9 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k """ # Generating data for all splits split_dict = SplitDict(dataset_name=self.name) - split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) - split_generators = self._split_generators(dl_manager, **split_generators_kwargs) + # split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs) + split_generators = self._split_generators(dl_manager, **prepare_split_kwargs) + prepare_split_kwargs.pop("split", None) # Checksums verification if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums: @@ -1301,7 +1306,7 @@ def _download_post_processing_resources( return None @abc.abstractmethod - def _split_generators(self, dl_manager: DownloadManager): + def _split_generators(self, dl_manager: DownloadManager, **split_generators_kwargs): """Specify feature dictionary generators and dataset splits. This function returns a list of `SplitGenerator`s defining how to generate diff --git a/src/datasets/load.py b/src/datasets/load.py index 7ad84e78f3f..433dc12dbd9 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1785,6 +1785,7 @@ def load_dataset( verification_mode=verification_mode, try_from_hf_gcs=try_from_hf_gcs, num_proc=num_proc, + split=split, ) # Build dataset for splits