diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 25b7b6993092..055919ce1184 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -703,6 +703,7 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
 
     def _activate_neftune(self, model):
         r"""
@@ -2969,7 +2970,20 @@ def _save_rng_state(self, output_dir):
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.is_fsdp_xla_v1_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
+                xm.save(
+                    optm,
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
+                    master_only=False,
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
@@ -3047,11 +3061,26 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 )
             )
         )
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            if self.is_fsdp_xla_v1_enabled
+            else checkpoint_file_exists
+        )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                if self.is_fsdp_xla_v1_enabled:
+                    optimizer_state = torch.load(
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
+                    )
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
+                else:
+                    optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)
@@ -3464,7 +3493,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         model = self.model
         xm.mark_step()
 
-        if xm.is_master_ordinal():
+        if xm.is_master_ordinal(local=False):
             os.makedirs(output_dir, exist_ok=True)
             torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
@@ -3472,7 +3501,40 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(model, supported_classes):
+        if self.is_fsdp_xla_v1_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if self.args.should_save:
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+                    save_model=False,
+                )
+                model = model.module.module
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        elif not isinstance(model, supported_classes):
             if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,
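For reference, a minimal offline sketch (not part of the patch) of how the per-rank FSDP (XLA v1) shards written by the patched `_save_tpu` could be consolidated into a single full state dict after training. The consolidation call mirrors the one in the diff; `checkpoint_dir` and the output filename are hypothetical, and the shard suffix assumes the default `WEIGHTS_NAME` of `pytorch_model.bin`:

# Hedged sketch, assuming the checkpoint directory contains shards named
# rank*-of-*-pytorch_model.bin as written by the patched `_save_tpu`.
# `checkpoint_dir` and the consolidated filename are illustrative only.
import os

import torch
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

checkpoint_dir = "output/checkpoint-500"  # hypothetical path to a saved checkpoint

# Same call pattern as in the diff: rebuild the full (unsharded) state dict
# from every rank's shard without letting the helper write it to disk.
full_state_dict, _ = consolidate_sharded_model_checkpoints(
    ckpt_prefix=os.path.join(checkpoint_dir, ""),
    ckpt_suffix="rank*-of-*-pytorch_model.bin",
    save_model=False,
)

# Persist the consolidated weights wherever convenient.
torch.save(full_state_dict, os.path.join(checkpoint_dir, "pytorch_model_consolidated.bin"))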