86 changes: 64 additions & 22 deletions src/transformers/trainer.py
@@ -17,6 +17,7 @@
"""

import contextlib
import glob
import inspect
import math
import os
@@ -1193,7 +1194,13 @@ def train(
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None:
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
# SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
checkpoint_file_exists = (
glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
if is_sagemaker_mp_enabled()
else os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
)
Comment on lines +1198 to +1202

Collaborator: This is used several times, could we refactor it into a util function that takes the filename?

Contributor Author: I will do that.
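As a minimal sketch of the util function suggested above (the helper name and signature below are hypothetical, not part of this PR), the repeated existence check could be factored out roughly like this:

import glob
import os


def checkpoint_file_exists(folder, filename, partial):
    # SMP partial checkpoints are written as {filename}_{pp_rank}_{tp_rank} or
    # {filename}_{pp_rank}_{tp_rank}_{rdp_rank}, so match the prefix with glob
    # instead of testing for a single file.
    path = os.path.join(folder, filename)
    if partial:
        return len(glob.glob(path + "_*")) > 0
    return os.path.isfile(path)

Each call site would then reduce to something like checkpoint_file_exists(checkpoint, WEIGHTS_NAME, is_sagemaker_mp_enabled()).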


if not checkpoint_file_exists:
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

logger.info(f"Loading model from {resume_from_checkpoint}).")
@@ -1211,6 +1218,9 @@ def train(
if args.deepspeed:
# will be resumed in deepspeed_init
pass
elif is_sagemaker_mp_enabled():
# will be resumed after model is wrapped
pass
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
@@ -1299,6 +1309,10 @@ def train(

model = self._wrap_model(self.model_wrapped)

if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
state_dict = smp.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=True)
model.load_state_dict(state_dict)

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
@@ -1561,13 +1575,19 @@ def train(
xm.rendezvous("load_best_model_at_end")
elif args.local_rank != -1:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()

logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)

best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
if os.path.exists(best_model_path):
checkpoint_file_exists = (
glob.glob(best_model_path + "_*") if is_sagemaker_mp_enabled() else os.path.exists(best_model_path)
)

if checkpoint_file_exists:
if self.deepspeed:
# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
@@ -1579,6 +1599,9 @@ def train(
self.deepspeed.load_checkpoint(
self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
)
elif is_sagemaker_mp_enabled():
state_dict = smp.load(best_model_path, partial=True)
model.load_state_dict(state_dict)
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
@@ -1741,17 +1764,20 @@ def _save_checkpoint(self, model, trial, metrics=None):
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
if smp.rdp_rank() == 0:
# Consolidate the state dict on all processes of rdp_rank 0
opt_state_dict = self.optimizer.state_dict()
# Save it and the scheduler on the main process
if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
if self.args.should_save or smp.state.cfg.shard_optimizer_state:
smp.save(
opt_state_dict,
os.path.join(output_dir, OPTIMIZER_NAME),
partial=True,
v3=smp.state.cfg.shard_optimizer_state,
)
if self.args.should_save:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
elif self.args.should_save and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -1822,9 +1848,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return

if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
os.path.join(checkpoint, SCHEDULER_NAME)
):
checkpoint_file_exists = (
glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
if is_sagemaker_mp_enabled()
else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1840,9 +1869,16 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
if is_sagemaker_mp_enabled():

def opt_load_hook(mod, opt):
opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))

self.model_wrapped.register_post_step_hook(opt_load_hook)
else:
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
@@ -2114,7 +2150,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
state_dict = self.model_wrapped.state_dict()
state_dict = self.model_wrapped.local_state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
@@ -2202,9 +2238,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if is_sagemaker_mp_enabled():
smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
if is_sagemaker_mp_enabled():
smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
Comment on lines +2246 to +2249

Collaborator: Not calling save_pretrained here means the config will not be saved and the checkpoint won't be able to be loaded with from_pretrained independently of the training. It's not a regular checkpoint anyway, so maybe it's okay. Flagging this here anyway.

Contributor Author: SMP checkpoints are saved partially, hence we do not want to shard SMP checkpoints. In order to use save_pretrained for SMP, we need to skip shard_checkpoint for SMP. Independent of max_shard_size, shard_checkpoint is called, and we hit errors in shard_checkpoint since SMP checkpoints are different. If shard_checkpoint can be made optional in save_pretrained, we can use save_pretrained with save_function=smp.save. In my previous PR I tried to skip shard_checkpoint for SMP, but the feedback was not to change save_pretrained.

from_pretrained won't work for SMP models. We are working on how to support fine-tuning. In this PR, I added support for partial checkpoint saving/loading during training.
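As a rough illustration of the save_function idea above: save_function is an existing save_pretrained argument, but this sketch only works under the assumption that shard_checkpoint is skipped for SMP, which is exactly the change that was not accepted.

import functools

# Hypothetical: viable only if save_pretrained bypasses shard_checkpoint for SMP,
# since partial (per-rank) SMP checkpoints cannot go through the sharding logic.
smp_save = functools.partial(smp.save, partial=True)
self.model.save_pretrained(output_dir, state_dict=state_dict, save_function=smp_save)

Going through save_pretrained would also write the model config, which is what the comment above flags as missing from the partial checkpoints.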

if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

2 changes: 1 addition & 1 deletion src/transformers/training_args.py
@@ -1219,7 +1219,7 @@ def should_save(self):
return self.local_process_index == 0
else:
if is_sagemaker_mp_enabled():
return smp.rank() == 0
return smp.rdp_rank() == 0
else:
return self.process_index == 0
