@@ -3,6 +3,7 @@
 import logging
 import pickle
 import shutil
+import traceback
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
@@ -46,10 +47,12 @@
 
 from .aliases import PathOrStr
 from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
+from .exceptions import OLMoCheckpointError
 from .optim import Optimizer, fix_optim_state_dict
 from .safetensors_util import safetensors_file_to_state_dict
 from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size
 from .util import (
+    _get_s3_client,
     default_thread_count,
     dir_is_empty,
     get_bytes_range,
@@ -319,7 +322,10 @@ def __init__(
             path,
             single_file_per_rank=single_file_per_rank,
             sync_files=sync_files,
-            thread_count=thread_count or default_thread_count(),
+            # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
+            # returns because uploading big checkpoint files with multiple threads causes
+            # boto3 to fail in weird ways.
+            thread_count=thread_count or 1,
             per_thread_copy_ahead=per_thread_copy_ahead,
         )
         self.upload_to = None if upload_to is None else upload_to.rstrip("/")
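Note that `thread_count or 1` means any falsy value, including an explicit `thread_count=0`, also resolves to 1 and serializes uploads. A minimal sketch of the defaulting behavior this hunk replaces; the body of `default_thread_count()` below is an assumption for illustration, not the actual `.util` implementation:

```python
import os

def default_thread_count() -> int:
    # Hypothetical stand-in for the helper in .util; assumed to scale
    # with CPU count, similar to ThreadPoolExecutor's own default.
    return min(32, (os.cpu_count() or 1) + 4)

# Before this change: None -> default_thread_count() (many upload threads).
# After this change:  None -> 1 (uploads run one at a time).
for requested in (None, 0, 4):
    resolved = requested or 1
    print(f"requested={requested!r} -> thread_count={resolved}")
```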
@@ -336,6 +342,12 @@ def write_data(
         for write_result in fut.wait():
             files_to_upload.add(write_result.storage_data.relative_path)
 
+        # Create the global S3 client up front to work around a threading issue in boto.
+        if self.upload_to.startswith("s3://"):
+            _get_s3_client("s3")
+        elif self.upload_to.startswith("r2://"):
+            _get_s3_client("r2")
+
         with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
             futures = []
             for fname in files_to_upload:
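Pre-creating the client matters because boto3 clients are generally safe to *share* across threads but not to *create* concurrently; calling the getter once on the main thread means the thread pool only ever reuses an existing client. A minimal sketch of what a cached `_get_s3_client` could look like, assuming it lazily builds one client per scheme (the `r2` endpoint handling and env var name here are assumptions, not the actual `.util` code):

```python
import os
import threading

import boto3

_s3_clients: dict = {}
_s3_client_lock = threading.Lock()

def _get_s3_client(scheme: str = "s3"):
    # Build and cache one client per scheme; the lock makes first-time
    # creation safe even if multiple threads race to get here.
    with _s3_client_lock:
        if scheme not in _s3_clients:
            # Hypothetical: Cloudflare R2 is S3-compatible but needs an
            # account-specific endpoint URL.
            endpoint = os.environ.get("R2_ENDPOINT_URL") if scheme == "r2" else None
            _s3_clients[scheme] = boto3.client("s3", endpoint_url=endpoint)
        return _s3_clients[scheme]
```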
@@ -344,7 +356,13 @@ def write_data(
                 log.info(f"Uploading {source} to {target}...")
                 futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
             for f in as_completed(futures):
-                f.result()
+                try:
+                    f.result()
+                except BaseException:
+                    # NOTE: we might get an error here that can't be pickled, which causes a different failure
+                    # later when PyTorch tries to reduce that error across ranks. So here we just make
+                    # sure we're raising a simple error type that can be pickled.
+                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
         return fut
 
     def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
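The try/except wrapper exists because torch.distributed moves exceptions between ranks with pickle, and third-party errors whose `__init__` takes extra required arguments (as some botocore errors do) fail to round-trip, producing a confusing secondary failure. A small self-contained illustration of the failure mode and the fix, using a stand-in for the imported `OLMoCheckpointError` (assumed to be a plain `Exception` subclass):

```python
import pickle
import traceback

class OLMoCheckpointError(Exception):
    """Stand-in: a single-string-argument error pickles cleanly."""

class FussyError(Exception):
    # Stand-in for a third-party error that breaks pickling: unpickling
    # re-calls __init__ with .args, which doesn't match this signature.
    def __init__(self, code: int, *, detail: str):
        super().__init__(f"{code}: {detail}")

try:
    pickle.loads(pickle.dumps(FussyError(42, detail="boom")))
except TypeError as e:
    print("round-trip failed:", e)

# The workaround: flatten the full traceback into a string and raise a
# simple, picklable error type carrying that string instead.
try:
    raise FussyError(42, detail="boom")
except BaseException:
    wrapped = OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
    assert pickle.loads(pickle.dumps(wrapped)).args == wrapped.args
```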
@@ -386,13 +404,26 @@ def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
         return (read_item, content)
 
     def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
+        # Create the global S3 client up front to work around a threading issue in boto.
+        if isinstance(self.path, str):
+            if self.path.startswith("s3://"):
+                _get_s3_client("s3")
+            elif self.path.startswith("r2://"):
+                _get_s3_client("r2")
+
         with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
             read_item_content_futures = []
             for read_item in plan.items:
                 read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
             read_item_content_results = []
             for f in as_completed(read_item_content_futures):
-                read_item_content_results.append(f.result())
+                try:
+                    read_item_content_results.append(f.result())
+                except BaseException:
+                    # NOTE: we might get an error here that can't be pickled, which causes a different failure
+                    # later when PyTorch tries to reduce that error across ranks. So here we just make
+                    # sure we're raising a simple error type that can be pickled.
+                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
 
         # Modified from `FileSystemReader.read_data()`
         for read_item, content in read_item_content_results: