diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index a24fd8947..bb6e7ee93 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -340,7 +340,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     else:
         try:
             iteration = state_dict['iteration']
-            args.consumed_train_tokens = state_dict['tokens']
+            if 'tokens' in state_dict:
+                args.consumed_train_tokens = state_dict['tokens']
         except KeyError:
             try:  # Backward compatible with older checkpoints
                 iteration = state_dict['total_iters']
diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py
index 5435b60b4..ae1fcdb2b 100644
--- a/megatron/learning_rates.py
+++ b/megatron/learning_rates.py
@@ -185,6 +185,8 @@ def load_state_dict(self, sd):
             num_steps = sd['num_iters']
         else:
             num_steps = sd['num_steps']
-        self.warmup_tokens = sd['warmup_tokens']
-        self.num_tokens = sd['num_tokens']
+        if 'warmup_tokens' in sd:
+            self.warmup_tokens = sd['warmup_tokens']
+        if 'num_tokens' in sd:
+            self.num_tokens = sd['num_tokens']
        self.step(num_steps, self.num_tokens)
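
For context, a minimal sketch of the pattern both hunks apply: token-counter keys are guarded before access, so checkpoints written before the token counters were introduced load without a KeyError and keep the counters' existing defaults. The helper and sample state dicts below are hypothetical illustrations, not part of Megatron's checkpoint format or API.

def load_token_counters(state_dict, defaults=None):
    """Return (consumed_train_tokens, warmup_tokens, num_tokens), falling back
    to defaults for keys that older checkpoints do not contain."""
    defaults = defaults or {'tokens': 0, 'warmup_tokens': 0, 'num_tokens': 0}
    return (
        state_dict.get('tokens', defaults['tokens']),
        state_dict.get('warmup_tokens', defaults['warmup_tokens']),
        state_dict.get('num_tokens', defaults['num_tokens']),
    )

# Newer checkpoint: all counters present.
new_sd = {'iteration': 1000, 'tokens': 2_048_000,
          'warmup_tokens': 100_000, 'num_tokens': 2_048_000}
# Older checkpoint: counters absent; unguarded indexing would raise KeyError here.
old_sd = {'iteration': 1000}

print(load_token_counters(new_sd))  # (2048000, 100000, 2048000)
print(load_token_counters(old_sd))  # (0, 0, 0)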