Workaround for pytorch/pytorch#138575 distributed checkpoint loading bug (#647)

Upstream PyTorch has a bug (pytorch/pytorch#138575) in `dcp.load()`
that causes the objects pointed to by `state_dict` to diverge from the
objects in use by the train loop. Specifically, torch is supposed to
update the Stateful elements in `state_dict` in-place; however, it also
replaces the references in `state_dict` with the newly loaded objects,
leaving the CheckpointManager instance with a stale train
state/optimizer/etc.

This causes an issue when loading a model from a checkpoint and
subsequently saving a new checkpoint: the stale loaded train
state/optimizer/etc. are saved instead of the current values. As a
result, if a training run is pre-empted and resumed twice, the results
are subtly wrong.

Until the fix for pytorch/pytorch#138575 is merged upstream, this
should work around the issue.
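
A minimal sketch of the reference-preserving pattern the diff below applies, assuming a `states` dict like the one `CheckpointManager` builds (the helper name `load_with_workaround` is hypothetical; `dcp.load()` and `Stateful` come from `torch.distributed.checkpoint`):

```python
# Minimal sketch (not the actual CheckpointManager code): keep references to
# the original Stateful objects and restore them after dcp.load().
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful


def load_with_workaround(states: dict, checkpoint_id: str) -> None:
    # The train loop holds on to these objects; remember them before loading.
    original_stateful_states = {
        k: v for k, v in states.items() if isinstance(v, Stateful)
    }
    # dcp.load() updates the Stateful objects in-place, but due to
    # pytorch/pytorch#138575 it also replaces the dict entries with newly
    # constructed objects, so `states` stops pointing at the live objects.
    dcp.load(states, checkpoint_id=checkpoint_id)
    # Restore the original references; their internal state was already
    # updated in-place, so nothing is lost and future saves see the objects
    # the train loop is actually mutating.
    states.update(original_stateful_states)
```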

From @karan-dalal and me, thanks for all the work on TorchTitan! We're
glad to be able to contribute back upstream in return!

---------

Co-authored-by: Andrew Gu <[email protected]>
arjvik and awgu authored Oct 23, 2024
1 parent b19456a commit 1060fea
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions torchtitan/checkpoint.py
@@ -469,6 +469,15 @@ def load(self, step: int = -1) -> bool:

# We won't have optimizer states to load, if we are loading a seed checkpoint
states = {"model": self.states["model"]} if step == 0 else self.states
# PyTorch bug: (pytorch/pytorch#138575)
# dcp.load() replaces the values of stateful elements in `states` with new objects
# from loading the checkpoint, in addition to updating the states of the original
# objects from `states` in-place. This is a problem because the state_dict no longer
# refers to the objects being used in the train loop, meaning any future checkpoints
# will not include updates to these objects (such as updated optimizer states, etc.)
original_stateful_states = {
k: v for k, v in states.items() if isinstance(v, Stateful)
}
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
dcp.load(
@@ -478,6 +487,9 @@
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
# bugfix from above: restore the original stateful objects,
# whose states were already updated in-place by dcp.load()
states.update(original_stateful_states)
return True

def _purge_stale_checkpoints(self):
