diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 9f0d947c78a..aaea13c474e 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -638,6 +638,7 @@ def download_and_prepare(
         max_shard_size: Optional[Union[int, str]] = None,
         num_proc: Optional[int] = None,
         storage_options: Optional[dict] = None,
+        split: Optional[str] = None,
         **download_and_prepare_kwargs,
     ):
         """Downloads and prepares dataset for reading.
@@ -879,6 +880,8 @@ def incomplete_dir(dirname):
                     logger.warning("HF google storage unreachable. Downloading and preparing it from source")
             if not downloaded_from_gcs:
                 prepare_split_kwargs = {"file_format": file_format}
+                if split:
+                    prepare_split_kwargs = {**prepare_split_kwargs, "split": split}
                 if max_shard_size is not None:
                     prepare_split_kwargs["max_shard_size"] = max_shard_size
                 if num_proc is not None:
@@ -955,8 +959,8 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k
         """
         # Generating data for all splits
         split_dict = SplitDict(dataset_name=self.name)
-        split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs)
-        split_generators = self._split_generators(dl_manager, **split_generators_kwargs)
+        split_generators = self._split_generators(dl_manager, **prepare_split_kwargs)
+        prepare_split_kwargs.pop("split", None)
 
         # Checksums verification
         if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums:
@@ -1301,7 +1304,7 @@ def _download_post_processing_resources(
         return None
 
     @abc.abstractmethod
-    def _split_generators(self, dl_manager: DownloadManager):
+    def _split_generators(self, dl_manager: DownloadManager, **split_generators_kwargs):
         """Specify feature dictionary generators and dataset splits.
 
         This function returns a list of `SplitGenerator`s defining how to generate
diff --git a/src/datasets/load.py b/src/datasets/load.py
index c076a3cc465..6b9b4c42b46 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1782,6 +1782,7 @@ def load_dataset(
         verification_mode=verification_mode,
         try_from_hf_gcs=try_from_hf_gcs,
         num_proc=num_proc,
+        split=split,
     )
 
     # Build dataset for splits