54 changes: 31 additions & 23 deletions src/transformers/modeling_utils.py
@@ -1361,32 +1361,40 @@ def save_pretrained(
if ignore_key in state_dict.keys():
del state_dict[ignore_key]

# Shard the model if it is too big.
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
from .utils import is_sagemaker_mp_enabled

# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
os.remove(full_filename)
if is_sagemaker_mp_enabled():
# Do not shard checkpoints when sagemaker model parallel is enabled
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
save_function(state_dict, output_model_file)
logger.info(f"Model weights saved in {output_model_file}")
else:
# Shard the model if it is too big.
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)

# Save the model
for shard_file, shard in shards.items():
save_function(shard, os.path.join(save_directory, shard_file))
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
os.remove(full_filename)

if index is None:
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
else:
save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
# Save the model
for shard_file, shard in shards.items():
save_function(shard, os.path.join(save_directory, shard_file))

if index is None:
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
else:
save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
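For context on the non-SMP branch, `shard_checkpoint` splits a state dict into size-bounded shards plus an index mapping each parameter to its shard file. A minimal sketch of its behavior in isolation (the toy model and the 1MB limit are illustrative):

import torch
from torch import nn
from transformers.modeling_utils import shard_checkpoint

# Toy model just to produce a state dict; real checkpoints are far larger.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
state_dict = model.state_dict()

# Split into shards of at most ~1MB each (illustrative limit; the default is 10GB).
shards, index = shard_checkpoint(state_dict, max_shard_size="1MB")

for shard_file, shard in shards.items():
    torch.save(shard, shard_file)  # e.g. pytorch_model-00001-of-00004.bin

# index is None when everything fits in one shard; otherwise it carries a
# "weight_map" from parameter name to shard file, which save_pretrained
# writes out as pytorch_model.bin.index.json.
if index is not None:
    print(list(index["weight_map"].items())[:3])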
122 changes: 95 additions & 27 deletions src/transformers/trainer.py
@@ -836,16 +836,20 @@ def create_optimizer(self):
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method.
"""
if is_sagemaker_mp_enabled():
opt_model = self.model_wrapped
else:
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
"params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
"weight_decay": 0.0,
},
]
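The change above sources the parameters from the wrapped model under SMP (so the optimizer only sees this rank's partition) and, as before, exempts biases and LayerNorm parameters from weight decay. A standalone PyTorch sketch of that grouping (toy model, placeholder hyperparameters):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

decay, no_decay = [], []
for module in model.modules():
    for name, param in module.named_parameters(recurse=False):
        # Biases and normalization weights conventionally receive no weight decay.
        if isinstance(module, nn.LayerNorm) or name.endswith("bias"):
            no_decay.append(param)
        else:
            decay.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=5e-5,
)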
@@ -1032,6 +1036,8 @@ def _wrap_model(self, model, training=True):
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
return self.model_wrapped
if self.args.smp_tensor_parallel_full_model:
smp.set_tensor_parallelism(model)
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

# already initialized its own DDP and AMP
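A minimal sketch of the wrapping order used above, outside the Trainer; it assumes the process was launched under the SageMaker Model Parallel runtime, so that `smdistributed.modelparallel.torch` is importable and `smp.init()` has a configuration to read:

import torch.nn as nn
import smdistributed.modelparallel.torch as smp

smp.init()

model = nn.Linear(16, 16)  # stand-in for the real Transformer model
tensor_parallel_full_model = True  # mirrors the new smp_tensor_parallel_full_model flag

if tensor_parallel_full_model:
    # Tensor parallelism is requested before smp.DistributedModel wraps and
    # partitions the model, the same order _wrap_model uses above.
    smp.set_tensor_parallelism(model)

model = smp.DistributedModel(model, backward_passes_per_step=1)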
@@ -1166,7 +1172,15 @@ def train(
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None:
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
if self.args.smp_load_partial:
import glob

# SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
checkpoint_file_exists = glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
else:
checkpoint_file_exists = os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))

if not checkpoint_file_exists:
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

logger.info(f"Loading model from {resume_from_checkpoint}).")
@@ -1185,13 +1199,14 @@
# will be resumed in deepspeed_init
pass
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
if not is_sagemaker_mp_enabled():
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)

# release memory
del state_dict
# release memory
del state_dict

# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
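The non-SMP branch above loads the checkpoint onto the CPU first so that only one GPU copy of the weights ever exists at a time; a standalone sketch (toy model, hypothetical path):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 8).to(device)

# map_location="cpu" keeps the loaded tensors in host memory; load_state_dict
# then copies them onto whatever device each model parameter lives on.
state_dict = torch.load("/tmp/checkpoint-500/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
del state_dict  # release the CPU copy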
@@ -1272,6 +1287,12 @@ def train(

model = self._wrap_model(self.model_wrapped)

if resume_from_checkpoint is not None and is_sagemaker_mp_enabled():
state_dict = smp.load(
os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=self.args.smp_load_partial
)
model.load_state_dict(state_dict)

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
@@ -1534,13 +1555,21 @@ def train(
xm.rendezvous("load_best_model_at_end")
elif args.local_rank != -1:
dist.barrier()
elif is_sagemaker_mp_enabled():
smp.barrier()

logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)

best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
if os.path.exists(best_model_path):
smp_checkpoint_file_exists = False
if self.args.smp_load_partial:
import glob

# SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
smp_checkpoint_file_exists = glob.glob(best_model_path + "_*")

if os.path.exists(best_model_path) or smp_checkpoint_file_exists:
if self.deepspeed:
# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
Expand All @@ -1554,9 +1583,13 @@ def train(
)
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
if is_sagemaker_mp_enabled():
state_dict = smp.load(best_model_path, partial=self.args.smp_load_partial)
model.load_state_dict(state_dict)
else:
state_dict = torch.load(best_model_path, map_location="cpu")
# If the model is on the GPU, it still works!
self._load_state_dict_in_model(state_dict)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
@@ -1714,12 +1747,20 @@ def _save_checkpoint(self, model, trial, metrics=None):
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
elif is_sagemaker_mp_enabled():
if smp.rdp_rank() == 0:
# Consolidate the state dict on all processes of rdp_rank 0
# Consolidate the state dict on all processes
if self.args.smp_save_partial:
opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
else:
opt_state_dict = self.optimizer.state_dict()
if self.args.should_save or smp.state.cfg.shard_optimizer_state:
smp.save(
opt_state_dict,
os.path.join(output_dir, OPTIMIZER_NAME),
partial=self.args.smp_save_partial,
v3=smp.state.cfg.shard_optimizer_state,
)
# Save it and the scheduler on the main process
if self.args.should_save:
torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
@@ -1795,9 +1836,15 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return

if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
os.path.join(checkpoint, SCHEDULER_NAME)
):
if self.args.smp_load_partial:
import glob

# SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
checkpoint_file_exists = glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
else:
checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))

if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1813,9 +1860,18 @@
self.lr_scheduler.load_state_dict(lr_scheduler_state)
else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
if is_sagemaker_mp_enabled():

def opt_load_hook(mod, opt):
opt.load_state_dict(
smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=self.args.smp_load_partial)
)

self.model_wrapped.register_post_step_hook(opt_load_hook)
else:
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
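Under SMP the model is only partitioned when the first smp step runs, so the optimizer's parameters may not all exist on a given rank before then; deferring the restore through a post-step hook, as above, waits until they do. A hedged sketch of the hook in isolation (constants match the Trainer's, the path is hypothetical, and `model_wrapped` is assumed to already be an `smp.DistributedModel`):

import os
import smdistributed.modelparallel.torch as smp

OPTIMIZER_NAME = "optimizer.pt"  # same constant the Trainer uses
checkpoint = "/tmp/checkpoint-500"  # hypothetical

def opt_load_hook(mod, opt):
    # Called by SMP after stepping, once the model has been partitioned and
    # this rank's parameters exist.
    opt.load_state_dict(
        smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)
    )

model_wrapped.register_post_step_hook(opt_load_hook)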
@@ -2087,7 +2143,10 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
# Calling the state_dict needs to be done on the wrapped model and on all processes.
state_dict = self.model_wrapped.state_dict()
if self.args.smp_save_partial:
state_dict = self.model_wrapped.local_state_dict()
else:
state_dict = self.model_wrapped.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
@@ -2175,9 +2234,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if self.args.smp_save_partial:
smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=self.args.smp_save_partial)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.args.smp_save_partial:
self.model.save_pretrained(output_dir, save_function=smp.save, state_dict=state_dict)
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

@@ -2592,7 +2657,10 @@ def _nested_gather(self, tensors, name=None):
name = "nested_gather"
tensors = nested_xla_mesh_reduce(tensors, name)
elif is_sagemaker_mp_enabled():
tensors = smp_gather(tensors)
if smp.state.cfg.ddp:
tensors = distributed_concat(tensors.cuda())
else:
tensors = smp_gather(tensors)
elif self.args.local_rank != -1:
tensors = distributed_concat(tensors)
return tensors
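For reference, `distributed_concat` amounts to an all-gather followed by concatenation; a minimal torch.distributed sketch of the same operation (it assumes the process group is already initialized, as it is inside the Trainer):

import torch
import torch.distributed as dist

def concat_across_ranks(tensor: torch.Tensor) -> torch.Tensor:
    # Every rank contributes its local tensor; the result is the concatenation
    # over the batch dimension, in rank order.
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return torch.cat(gathered, dim=0)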
37 changes: 36 additions & 1 deletion src/transformers/training_args.py
@@ -423,6 +423,15 @@ class TrainingArguments:
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions, and references for scoring calculation in the Metric class.
smp_save_partial (`bool`, *optional*, defaults to `False`):
If True, saves checkpoints partially. `"smp_save_partial"` can only be used with the SageMaker Model
Parallel library.
smp_load_partial (`bool`, *optional*, defaults to `False`):
If True, loads partial checkpoints. `"smp_load_partial"` can only be used with the SageMaker Model
Parallel library.
smp_tensor_parallel_full_model (`bool`, *optional*, defaults to `False`):
If True, applies tensor parallelism to the full model. `"smp_tensor_parallel_full_model"` can only be
used with the SageMaker Model Parallel library.
"""

output_dir: str = field(
@@ -750,6 +759,12 @@ class TrainingArguments:
include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
)
smp_save_partial: bool = field(default=False, metadata={"help": "Save checkpoints partially for SMP."})
smp_load_partial: bool = field(default=False, metadata={"help": "Load partial checkpoints for SMP."})
smp_tensor_parallel_full_model: bool = field(
default=False, metadata={"help": "Enables tensor parallelism for full model in SMP."}
)

# Deprecated arguments
fp16_backend: str = field(
default="auto",
@@ -985,6 +1000,23 @@ def __post_init__(self):
FutureWarning,
)

if (self.smp_save_partial or self.smp_load_partial) and not is_sagemaker_mp_enabled():
raise ValueError(
f"smp_save_partial and smp_load_partial can only be used with Sagemaker Model Parallel library."
)

if is_sagemaker_mp_enabled() and not self.smp_save_partial and smp.state.cfg.shard_optimizer_state:
warnings.warn(
"Optimizer sharding can only be used with partial checkpointing. Switching to partial checkpointing."
)
self.smp_save_partial = True

if is_sagemaker_mp_enabled() and self.smp_save_partial and self.load_best_model_at_end:
self.smp_load_partial = True

if is_sagemaker_mp_enabled() and (not self.smp_save_partial) and (self.save_strategy != IntervalStrategy.NO):
warnings.warn("Saving weights but not the optimizer state.")

def __str__(self):
self_as_dict = asdict(self)

@@ -1218,7 +1250,10 @@ def should_save(self):
return self.local_process_index == 0
else:
if is_sagemaker_mp_enabled():
return smp.rank() == 0
if self.smp_save_partial:
return smp.rdp_rank() == 0
else:
return smp.rank() == 0
else:
return self.process_index == 0
