Skip to content

Commit

Permalink
Merge pull request #1071 from bghira/bugfix/state-tracker-unseen-lost
Browse files Browse the repository at this point in the history
bugfix: restore sampler state on rank 0 correctly
  • Loading branch information
bghira authored Oct 14, 2024
2 parents d0b5f37 + dd06086 commit 6947840
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,13 +1468,16 @@ def init_resume_checkpoint(self, lr_scheduler):
structured_data={"message": f"Resuming model: {path}"},
message_type="init_resume_checkpoint",
)
training_state_filename = f"training_state.json"
if get_rank() > 0:
training_state_filename = f"training_state-{get_rank()}.json"
for _, backend in StateTracker.get_data_backends().items():
if "sampler" in backend:
backend["sampler"].load_states(
state_path=os.path.join(
self.config.output_dir,
path,
f"training_state-{get_rank()}.json",
training_state_filename,
),
)
self.state["global_resume_step"] = self.state["global_step"] = (
Expand Down

0 comments on commit 6947840

Please sign in to comment.