Workaround for pytorch/pytorch#138575 distributed checkpoint loading bug (#647)

Upstream PyTorch has a bug (pytorch/pytorch#138575) in `dcp.load()` that causes the objects pointed to by `state_dict` to diverge from the objects in use by the train loop. Specifically, torch is supposed to update the Stateful elements in `state_dict` in place, but it also replaces the references in `state_dict` with the newly loaded objects, leaving the CheckpointManager instance with a stale trainstate/optimizer/etc. This becomes a problem when loading a model from a checkpoint and subsequently saving a checkpoint: the loaded trainstate/optimizer/etc. are saved instead of the current values. As a result, if a training run is pre-empted and resumed twice, the results are subtly wrong. Until a fix for pytorch/pytorch#138575 lands upstream, this should work around the issue.

From @karan-dalal and me, thanks for all the work on TorchTitan! We're glad to be able to contribute back upstream in return!

---------

Co-authored-by: Andrew Gu <[email protected]>
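As a rough illustration of the kind of workaround described above (not necessarily the exact change in this commit), one can snapshot the caller's references before calling `dcp.load()` and restore them afterwards, assuming the in-place update of the Stateful objects still happens; the `load_checkpoint` helper and the `states` mapping below are hypothetical names, not TorchTitan's actual API:

```python
import torch.distributed.checkpoint as dcp


def load_checkpoint(states: dict, checkpoint_id: str) -> None:
    """Load a distributed checkpoint into `states` without letting
    pytorch/pytorch#138575 swap out the caller's object references.

    `states` maps names to Stateful objects (train state, model,
    optimizer, ...) held by the checkpoint manager.
    """
    # Keep the original references to the live train-loop objects.
    original = dict(states)

    # dcp.load() updates each Stateful in place, but due to the bug it may
    # also overwrite the dict entries with the newly loaded objects.
    dcp.load(states, checkpoint_id=checkpoint_id)

    # Restore the original references so a subsequent dcp.save(states)
    # serializes the live objects rather than the stale loaded copies.
    states.update(original)
```

With this shape of workaround, a later `dcp.save(states, checkpoint_id=...)` sees the same objects the train loop is mutating, so a run that is pre-empted and resumed more than once saves the current values rather than the values from the last load.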