src/transformers/trainer.py: 66 additions & 4 deletions
@@ -703,6 +703,7 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled

     def _activate_neftune(self, model):
         r"""
@@ -2969,7 +2970,20 @@ def _save_rng_state(self, output_dir):
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.is_fsdp_xla_v1_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
+                xm.save(
+                    optm,
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
+                    master_only=False,
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
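
Side note on the on-disk layout this hunk produces: with the v1 XLA FSDP path enabled, every process writes its own shard instead of a single optimizer state file. A minimal sketch of what one shard contains and how it can be inspected on CPU, assuming world_size=2 and the default `OPTIMIZER_NAME` ("optimizer.pt"); the paths are illustrative only, not part of the PR:

```python
import torch

# Illustrative paths: with world_size=2 the branch above writes one file per rank,
#   output_dir/rank0-of-2-optimizer.pt
#   output_dir/rank1-of-2-optimizer.pt
shard = torch.load("output_dir/rank0-of-2-optimizer.pt", map_location="cpu")

print(sorted(shard.keys()))      # ['optimizer', 'shard_metadata']
print(type(shard["optimizer"]))  # this rank's regular optimizer state_dict
```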
@@ -3047,11 +3061,26 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 )
             )
         )
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            if self.is_fsdp_xla_v1_enabled
+            else checkpoint_file_exists
+        )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                if self.is_fsdp_xla_v1_enabled:
+                    optimizer_state = torch.load(
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
+                    )
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
+                else:
+                    optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)
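
For readers following the resume path: the existence check only needs to see that some rank's shard is present, while the actual load is strictly per process. A hedged sketch of that logic, assuming 8 processes and the default optimizer file name; the checkpoint path and rank values are placeholders standing in for `self.args`:

```python
import glob
import os

import torch

# Placeholder values for self.args.world_size / self.args.process_index and OPTIMIZER_NAME.
checkpoint = "output/checkpoint-500"
world_size, process_index = 8, 3
optimizer_name = "optimizer.pt"

# Existence check: any rank's shard counts, mirroring the glob in the diff.
pattern = os.path.join(checkpoint, f"rank*-of-{world_size}-{optimizer_name}")
assert glob.glob(pattern), "no sharded optimizer checkpoint found"

# Each process then loads only its own shard on CPU; 'shard_metadata' is not needed on resume.
shard_path = os.path.join(checkpoint, f"rank{process_index}-of-{world_size}-{optimizer_name}")
optimizer_state = torch.load(shard_path, map_location="cpu")["optimizer"]
```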
@@ -3464,15 +3493,48 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         model = self.model
         xm.mark_step()

-        if xm.is_master_ordinal():
+        if xm.is_master_ordinal(local=False):
             os.makedirs(output_dir, exist_ok=True)
             torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(model, supported_classes):
+        if self.is_fsdp_xla_v1_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if self.args.should_save:
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+                    save_model=False,
+                )
+                model = model.module.module
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        elif not isinstance(model, supported_classes):
             if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
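
The consolidation step in the last hunk can also be reproduced offline from the rank files, for example if a run stops after the `save_full_checkpoints` rendezvous but before the master finishes writing the full checkpoint. A sketch using the same torch_xla helper, assuming the shard files follow the `rank*-of-*-pytorch_model.bin` naming produced above (the default `WEIGHTS_NAME`) and that the directory path is illustrative:

```python
import os

import torch
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

output_dir = "output/checkpoint-500"  # placeholder path

# Each rank file holds {"model": <sharded state_dict>, "shard_metadata": ...},
# which is the layout the consolidation helper expects.
full_state_dict, _ = consolidate_sharded_model_checkpoints(
    ckpt_prefix=os.path.join(output_dir, ""),
    ckpt_suffix="rank*-of-*-pytorch_model.bin",  # matches WEIGHTS_NAME in the diff
    save_model=False,
)
torch.save(full_state_dict, os.path.join(output_dir, "pytorch_model.bin"))
```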