Skip to content

Commit 50da7a4

Browse files
committed
Add work-arounds for new-style checkpointing issues
1 parent 6d42d7a commit 50da7a4

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

olmo/checkpoint.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import pickle
55
import shutil
6+
import traceback
67
from abc import ABCMeta, abstractmethod
78
from collections import defaultdict
89
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
@@ -46,10 +47,12 @@
4647

4748
from .aliases import PathOrStr
4849
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
50+
from .exceptions import OLMoCheckpointError
4951
from .optim import Optimizer, fix_optim_state_dict
5052
from .safetensors_util import safetensors_file_to_state_dict
5153
from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size
5254
from .util import (
55+
_get_s3_client,
5356
default_thread_count,
5457
dir_is_empty,
5558
get_bytes_range,
@@ -319,7 +322,10 @@ def __init__(
319322
path,
320323
single_file_per_rank=single_file_per_rank,
321324
sync_files=sync_files,
322-
thread_count=thread_count or default_thread_count(),
325+
# NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
326+
# returns because uploading big checkpoint files with multiple threads causes
327+
# boto3 to fail in weird ways.
328+
thread_count=thread_count or 1,
323329
per_thread_copy_ahead=per_thread_copy_ahead,
324330
)
325331
self.upload_to = None if upload_to is None else upload_to.rstrip("/")
@@ -336,6 +342,12 @@ def write_data(
336342
for write_result in fut.wait():
337343
files_to_upload.add(write_result.storage_data.relative_path)
338344

345+
# Create the global S3 client up front to work around a threading issue in boto.
346+
if self.upload_to.startswith("s3://"):
347+
_get_s3_client("s3")
348+
elif self.upload_to.startswith("r2://"):
349+
_get_s3_client("r2")
350+
339351
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
340352
futures = []
341353
for fname in files_to_upload:
@@ -344,7 +356,13 @@ def write_data(
344356
log.info(f"Uploading {source} to {target}...")
345357
futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
346358
for f in as_completed(futures):
347-
f.result()
359+
try:
360+
f.result()
361+
except BaseException:
362+
# NOTE: we might get an error here that can't be pickled, which causes a different failure
363+
# later when PyTorch tries to reduce that error across ranks. So here we just make
364+
# sure we're raising a simple error type that can be pickled.
365+
raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
348366
return fut
349367

350368
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
@@ -386,13 +404,26 @@ def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
386404
return (read_item, content)
387405

388406
def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
407+
# Create the global S3 client up front to work around a threading issue in boto.
408+
if isinstance(self.path, str):
409+
if self.path.startswith("s3://"):
410+
_get_s3_client("s3")
411+
elif self.path.startswith("r2://"):
412+
_get_s3_client("r2")
413+
389414
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
390415
read_item_content_futures = []
391416
for read_item in plan.items:
392417
read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
393418
read_item_content_results = []
394419
for f in as_completed(read_item_content_futures):
395-
read_item_content_results.append(f.result())
420+
try:
421+
read_item_content_results.append(f.result())
422+
except BaseException:
423+
# NOTE: we might get an error here that can't be pickled, which causes a different failure
424+
# later when PyTorch tries to reduce that error across ranks. So here we just make
425+
# sure we're raising a simple error type that can be pickled.
426+
raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
396427

397428
# Modified from `FileSystemReader.read_data()`
398429
for read_item, content in read_item_content_results:

olmo/exceptions.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
__all__ = ["OLMoError", "OLMoConfigurationError", "OLMoCliError", "OLMoEnvironmentError", "OLMoNetworkError"]
1+
__all__ = [
2+
"OLMoError",
3+
"OLMoConfigurationError",
4+
"OLMoCliError",
5+
"OLMoEnvironmentError",
6+
"OLMoNetworkError",
7+
"OLMoCheckpointError",
8+
]
29

310

411
class OLMoError(Exception):
@@ -31,6 +38,12 @@ class OLMoNetworkError(OLMoError):
3138
"""
3239

3340

41+
class OLMoCheckpointError(OLMoError):
42+
"""
43+
An error occurred reading or writing from a checkpoint.
44+
"""
45+
46+
3447
class OLMoThreadError(Exception):
3548
"""
3649
Raised when a thread fails.

0 commit comments

Comments
 (0)