3232import itertools
3333import multiprocessing
3434import os
35+ import time
3536from typing import Any , Dict , Optional , Union
3637
3738from absl import logging
@@ -108,9 +109,24 @@ class _ShardInfo:
108109 num_exceptions : int
109110
110111
def _load_dataset(
    hf_builder: hf_datasets.DatasetBuilder,
    split: str,
) -> hf_datasets.Dataset:
  """Efficiently loads a HuggingFace dataset from its builder.

  Args:
    hf_builder: HuggingFace dataset builder to load examples from.
    split: Name of the split to load.

  Returns:
    A regular `Dataset` when the builder has no `repo_id` (local-only
    builder, materialized via `as_dataset`); otherwise a streaming
    `IterableDataset`, so examples are fetched lazily from the Hub instead
    of being fully downloaded and prepared up front.
  """
  if hf_builder.repo_id is None:
    # Local builder with no Hub repository to stream from: materialize the
    # split directly from the builder.
    return hf_builder.as_dataset(split=split)
  # NOTE: `repo_id` is known non-None here, so the `or cache_dir` fallback
  # is defensive only; kept to preserve the original behavior.
  return hf_datasets.load_dataset(
      hf_builder.repo_id or hf_builder.cache_dir,
      hf_builder.config_id,
      split=split,
      streaming=True,
  )
125+
126+
111127def _write_shard (
112128 shard_spec : _ShardSpec ,
113- hf_builder ,
129+ hf_builder : hf_datasets . DatasetBuilder ,
114130 example_writer ,
115131 features : feature_lib .FeaturesDict ,
116132 ignore_hf_errors : bool ,
@@ -136,12 +152,19 @@ def _write_shard(
136152 def get_serialized_examples_iter ():
137153 nonlocal num_bytes
138154 nonlocal num_exceptions
139- dataset = hf_builder .as_dataset (
140- split = shard_spec .shard_split , run_post_process = False
155+ dataset = _load_dataset (
156+ hf_builder ,
157+ shard_spec .hf_split ,
141158 )
142- for i in range (shard_spec .num_examples ):
159+ dataset = iter (dataset )
160+ # Skipping the first `start_index` examples. `streaming=True` returns an
161+ # iterable dataset, so we cannot jump to a specific index. This is not too
162+ # costly because it takes <0.5 ms/element in the wikipedia dataset.
163+ for _ in range (shard_spec .start_index ):
164+ next (dataset )
165+ for _ in range (shard_spec .num_examples ):
143166 try :
144- hf_value = dataset [ i ]
167+ hf_value = next ( dataset )
145168 except Exception : # pylint: disable=broad-exception-caught
146169 num_exceptions += 1
147170 if ignore_hf_errors :
@@ -155,6 +178,7 @@ def get_serialized_examples_iter():
155178 num_bytes += len (serialized_example )
156179 yield serialized_example
157180
181+ start = time .time ()
158182 example_writer .write (
159183 os .fspath (shard_spec .path ),
160184 tqdm_utils .tqdm (
@@ -166,6 +190,11 @@ def get_serialized_examples_iter():
166190 mininterval = 1.0 ,
167191 ),
168192 )
193+ logging .info (
194+ 'Generated %s examples in %s seconds' ,
195+ shard_spec .num_examples ,
196+ time .time () - start ,
197+ )
169198
170199 return _ShardInfo (
171200 num_bytes = num_bytes ,
@@ -247,6 +276,7 @@ def __init__(
247276 self ._builder_config = self ._converted_builder_config
248277 self .generation_errors = []
249278 self ._ignore_hf_errors = ignore_hf_errors
279+ login_to_hf (self ._hf_hub_token )
250280
251281 @property
252282 def builder_config (self ) -> Optional [Any ]:
@@ -257,14 +287,6 @@ def _create_builder_config(
257287 ) -> Optional [dataset_builder .BuilderConfig ]:
258288 return self ._converted_builder_config
259289
260- @functools .lru_cache (maxsize = 1 )
261- def _hf_download_and_prepare (self ):
262- login_to_hf (self ._hf_hub_token )
263- self ._hf_builder .download_and_prepare (
264- num_proc = self ._hf_num_proc ,
265- verification_mode = self ._verification_mode ,
266- )
267-
268290 @property
269291 def _hf_info (self ) -> hf_datasets .DatasetInfo :
270292 """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +300,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
278300 )
279301
280302 def _hf_features (self ) -> hf_datasets .Features :
281- if not self ._hf_info .features :
282- # We need to download and prepare the data to know its features.
283- self ._hf_download_and_prepare ()
284-
285- return self ._hf_info .features
303+ # Return the features from the builder info.
304+ if self ._hf_info .features :
305+ return self ._hf_info .features
306+ # Return the features from the first split.
307+ for split in self ._hf_info .splits :
308+ ds = _load_dataset (
309+ self ._hf_builder ,
310+ split ,
311+ )
312+ if hasattr (ds , 'info' ) and ds .info .features :
313+ return ds .info .features
314+ raise ValueError ('No features found in the dataset.' )
286315
287316 def _info (self ) -> dataset_info_lib .DatasetInfo :
288317 return dataset_info_lib .DatasetInfo (
@@ -309,7 +338,6 @@ def _generate_splits(
309338 ) -> Sequence [splits_lib .SplitInfo ]:
310339 """Prepares the dataset by writing to shards directly."""
311340 del dl_manager , download_config # Unused.
312- self ._hf_download_and_prepare ()
313341
314342 shard_specs_by_split : dict [str , Sequence [_ShardSpec ]] = {}
315343 for hf_split , hf_split_info in self ._hf_info .splits .items ():
0 commit comments