src/transformers/trainer.py: 89 changes (59 additions, 30 deletions)
@@ -18,6 +18,7 @@
 import contextlib
 import functools
+import glob
 import inspect
 import math
 import os
@@ -1302,7 +1303,7 @@ def train(
         if resume_from_checkpoint is None:
             raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
            self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1401,6 +1402,9 @@ def _inner_training_loop(
 
         model = self._wrap_model(self.model_wrapped)
 
+        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+            self._load_from_checkpoint(resume_from_checkpoint, model)
+
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
             self.model_wrapped = model
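Note on the two hunks above: with SageMaker Model Parallel enabled, train() no longer loads the resume checkpoint into the bare self.model; the load is deferred until _inner_training_loop has called _wrap_model, and the wrapped model is passed to _load_from_checkpoint explicitly. A minimal, runnable sketch of that ordering rule (checkpoint_load_stage is a hypothetical helper for illustration, not a Trainer API):

from typing import Optional


def checkpoint_load_stage(resume_from_checkpoint: Optional[str], smp_enabled: bool) -> Optional[str]:
    """Where the resume checkpoint gets loaded, mirroring the branches in the diff."""
    if resume_from_checkpoint is None:
        return None
    # Under SMP the state dict must go into the wrapped (partitioned) model,
    # so loading waits until after _wrap_model; otherwise it happens up front.
    return "after _wrap_model, into the wrapped model" if smp_enabled else "in train(), into self.model"


assert checkpoint_load_stage("checkpoint-500", smp_enabled=False) == "in train(), into self.model"
assert checkpoint_load_stage("checkpoint-500", smp_enabled=True) == "after _wrap_model, into the wrapped model"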
@@ -1666,6 +1670,8 @@ def _inner_training_loop(
                 xm.rendezvous("load_best_model_at_end")
             elif args.local_rank != -1:
                 dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()
 
             self._load_best_model()
 
@@ -1688,7 +1694,12 @@ def _inner_training_loop(
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
 
-    def _load_from_checkpoint(self, resume_from_checkpoint):
+    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+
+        if model is None:
+            model = self.model
+        strict_load = is_sagemaker_mp_enabled()
+
         if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
             os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         ):
@@ -1713,20 +1724,22 @@ def _load_from_checkpoint(self, resume_from_checkpoint):
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
             # If the model is on the GPU, it still works!
-            load_result = self.model.load_state_dict(state_dict, strict=False)
-            self._issue_warnings_after_load(load_result)
-
+            load_result = model.load_state_dict(state_dict, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
             # release memory
             del state_dict
         else:
             # We load the sharded checkpoint
-            load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
 
     def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        strict_load = is_sagemaker_mp_enabled()
+        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path):
             if self.deepspeed:
                 # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
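Note on the strict_load flag introduced above: under SMP the state dict is loaded with strict=True and the warning step is skipped, while the default path keeps strict=False and reports mismatched keys through _issue_warnings_after_load. A small self-contained example of what that toggle changes, using plain PyTorch:

import torch

model = torch.nn.Linear(4, 2)
state_dict = {"weight": torch.zeros(2, 4)}  # deliberately missing "bias"

# strict=False returns the mismatches; Trainer turns these into warnings.
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)  # ['bias']

# strict=True raises instead, so there is nothing to warn about afterwards.
try:
    model.load_state_dict(state_dict, strict=True)
except RuntimeError as exc:
    print(f"strict load failed: {exc}")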
@@ -1743,12 +1756,13 @@
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(best_model_path, map_location="cpu")
                 # If the model is on the GPU, it still works!
-                load_result = self.model.load_state_dict(state_dict, strict=False)
-                self._issue_warnings_after_load(load_result)
+                load_result = model.load_state_dict(state_dict, strict=strict_load)
+                if not strict_load:
+                    self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
             # Best model is a sharded checkpoint
-            load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
         else:
             logger.warning(
                 f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
@@ -1886,17 +1900,21 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                 reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.rdp_rank() == 0:
-                # Consolidate the state dict on all processed of rdp_rank 0
-                opt_state_dict = self.optimizer.state_dict()
-                # Save it and the scheduler on the main process
-                if self.args.should_save:
-                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
-                    with warnings.catch_warnings(record=True) as caught_warnings:
-                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
-                    reissue_pt_warnings(caught_warnings)
-                    if self.do_grad_scaling:
-                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            smp.barrier()
+            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+            if self.args.should_save:
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+                if self.do_grad_scaling:
+                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
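Note on the new save path above: the optimizer state is taken with local_state_dict(gather_if_shard=False) and written through smp.save(..., partial=True), and the write is gated on smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state. The reading here (an assumption, not stated in the diff) is that with optimizer state sharding every rdp rank owns a distinct shard and must write it, while without sharding one rank's copy is enough. A tiny runnable sketch of just that gate, with the SMP calls replaced by plain arguments:

def should_write_optimizer_shard(rdp_rank: int, shard_optimizer_state: bool) -> bool:
    """Mirror of the `smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state` gate."""
    return rdp_rank == 0 or shard_optimizer_state


assert should_write_optimizer_shard(rdp_rank=0, shard_optimizer_state=False)
assert not should_write_optimizer_shard(rdp_rank=3, shard_optimizer_state=False)
assert should_write_optimizer_shard(rdp_rank=3, shard_optimizer_state=True)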
@@ -1945,6 +1963,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
+
         local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
         if local_rank == -1:
             torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
@@ -1967,9 +1986,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return
 
-        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
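Note on the existence check above: smp.save(..., partial=True) evidently writes suffixed per-rank files next to OPTIMIZER_NAME (that is what the "_*" glob implies; the exact suffix format is not shown in the diff), so a plain os.path.isfile("optimizer.pt") test would report a missing optimizer checkpoint. A self-contained sketch of the same check outside the Trainer:

import glob
import os

OPTIMIZER_NAME = "optimizer.pt"  # the constant Trainer uses for optimizer checkpoints


def optimizer_checkpoint_exists(checkpoint: str, smp_enabled: bool) -> bool:
    if smp_enabled:
        # Partial SMP checkpoints are split into several suffixed files, so glob for them.
        return bool(glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*"))
    return os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))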
@@ -1985,9 +2007,16 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if is_sagemaker_mp_enabled():
+
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
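Note on the hook above: instead of loading the optimizer state immediately, the SMP branch registers opt_load_hook as a post-step hook on the wrapped model, and the hook restores the partial checkpoint with smp.load(..., partial=True). The assumption behind this sketch (not spelled out in the diff) is that the partitioned optimizer state only becomes loadable after a step has run on the wrapped model, which is why the load is deferred. A minimal emulation of the pattern without smdistributed:

class FakeWrappedModel:
    """Stand-in for the SMP-wrapped model; only the hook mechanics are emulated."""

    def __init__(self):
        self._post_step_hooks = []

    def register_post_step_hook(self, hook):
        self._post_step_hooks.append(hook)

    def step(self, optimizer):
        # ... forward/backward/optimizer.step() would happen here ...
        for hook in self._post_step_hooks:
            hook(self, optimizer)


def opt_load_hook(mod, opt):
    print("optimizer state loaded after the first step")


wrapped = FakeWrappedModel()
wrapped.register_post_step_hook(opt_load_hook)
wrapped.step(optimizer=None)  # the hook fires once the step has completed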