Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ def download_and_prepare(
max_shard_size: Optional[Union[int, str]] = None,
num_proc: Optional[int] = None,
storage_options: Optional[dict] = None,
+ split: Optional[str] = None,
**download_and_prepare_kwargs,
):
"""Downloads and prepares dataset for reading.
Expand Down Expand Up @@ -879,6 +880,9 @@ def incomplete_dir(dirname):
logger.warning("HF google storage unreachable. Downloading and preparing it from source")
if not downloaded_from_gcs:
prepare_split_kwargs = {"file_format": file_format}
+ # if "split" in download_and_prepare_kwargs:
+ if split:
+ prepare_split_kwargs = {**prepare_split_kwargs, "split": split}
if max_shard_size is not None:
prepare_split_kwargs["max_shard_size"] = max_shard_size
if num_proc is not None:
Expand Down Expand Up @@ -955,8 +959,9 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k
"""
# Generating data for all splits
split_dict = SplitDict(dataset_name=self.name)
- split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs)
- split_generators = self._split_generators(dl_manager, **split_generators_kwargs)
+ # split_generators_kwargs = self._make_split_generators_kwargs(prepare_split_kwargs)
+ split_generators = self._split_generators(dl_manager, **prepare_split_kwargs)
+ prepare_split_kwargs.pop("split", None)

# Checksums verification
if verification_mode == VerificationMode.ALL_CHECKS and dl_manager.record_checksums:
Expand Down Expand Up @@ -1301,7 +1306,7 @@ def _download_post_processing_resources(
return None

@abc.abstractmethod
- def _split_generators(self, dl_manager: DownloadManager):
+ def _split_generators(self, dl_manager: DownloadManager, **split_generators_kwargs):
"""Specify feature dictionary generators and dataset splits.

This function returns a list of `SplitGenerator`s defining how to generate
Expand Down
1 change: 1 addition & 0 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,7 @@ def load_dataset(
verification_mode=verification_mode,
try_from_hf_gcs=try_from_hf_gcs,
num_proc=num_proc,
+ split=split,
)

# Build dataset for splits
Expand Down