From ba1571488d545c5048cca4ab64dd3acd142f2e39 Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Mon, 29 Jul 2024 16:39:06 +0800
Subject: [PATCH 1/6] Support save/load ckpt for XLA FSDP

---
 src/transformers/trainer.py | 47 +++++++++++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 25b7b6993092..fc2c08ea3114 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2969,7 +2969,14 @@ def _save_rng_state(self, output_dir):
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.is_fsdp_xla_enabled:
+                xm.save(
+                    self.optimizer.state_dict(),
+                    os.path.join(output_dir, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"),
+                    master_only=False
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
@@ -3047,11 +3054,19 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                     )
                 )
             )
+            checkpoint_file_exists = (
+                glob.glob(os.path.join(checkpoint, "rank*_" + OPTIMIZER_NAME))
+                if self.is_fsdp_xla_enabled else checkpoint_file_exists
+            )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                if self.is_fsdp_xla_enabled:
+                    optimizer_state = torch.load(os.path.join(
+                        checkpoint, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"), map_location="cpu")
+                else:
+                    optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
                 reissue_pt_warnings(caught_warnings)
@@ -3464,7 +3479,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
             model = self.model
             xm.mark_step()

-        if xm.is_master_ordinal():
+        if xm.is_master_ordinal(local=False):
             os.makedirs(output_dir, exist_ok=True)
             torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

@@ -3472,7 +3487,31 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(model, supported_classes):
+        if self.is_fsdp_xla_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir,
+                f"rank{self.args.process_index}_of_{self.args.world_size}_{WEIGHTS_NAME}.pth"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if xm.is_master_ordinal(local=False):
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*_of_*_{WEIGHTS_NAME}.pth",
+                    save_model=False)
+                torch.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+            # Remove temporary sharded checkpoints
+            xm.rendezvous("remove_unused_checkpoints")
+            os.remove(ckpt_path)
+        elif not isinstance(model, supported_classes):
             if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,

From 9910e855b65fa233e8a96169c5ba3064aa98d6a5 Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Mon, 29 Jul 2024 20:04:17 +0800
Subject: [PATCH 2/6] Fix bug for save

---
 src/transformers/trainer.py | 33 ++++++++++++++++++++++++++-------
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index fc2c08ea3114..8e953d81c3bf 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2969,9 +2969,13 @@ def _save_rng_state(self, output_dir):
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            if self.is_fsdp_xla_enabled:
+            if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
                 xm.save(
-                    self.optimizer.state_dict(),
+                    optm,
                     os.path.join(output_dir, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"),
                     master_only=False
                 )
             else:
                 xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -3056,15 +3060,18 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             )
             checkpoint_file_exists = (
                 glob.glob(os.path.join(checkpoint, "rank*_" + OPTIMIZER_NAME))
-                if self.is_fsdp_xla_enabled else checkpoint_file_exists
+                if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
+                else checkpoint_file_exists
             )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                if self.is_fsdp_xla_enabled:
+                if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
                     optimizer_state = torch.load(os.path.join(
                         checkpoint, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"), map_location="cpu")
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
                 else:
                     optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                 with warnings.catch_warnings(record=True) as caught_warnings:
@@ -3487,7 +3494,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if self.is_fsdp_xla_enabled:
+        if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
             ckpt = {
                 "model": model.state_dict(),
                 "shard_metadata": model.get_shard_metadata(),
@@ -3501,13 +3508,25 @@ def _save_tpu(self, output_dir: Optional[str] = None):
             xm.rendezvous("save_full_checkpoints")
             # Master save full checkpoint
-            if xm.is_master_ordinal(local=False):
+            if self.args.should_save:
                 from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                 full_state_dict, _ = consolidate_sharded_model_checkpoints(
                     ckpt_prefix=os.path.join(output_dir, ""),
                     ckpt_suffix=f"rank*_of_*_{WEIGHTS_NAME}.pth",
                     save_model=False)
-                torch.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                assert isinstance(model, FSDP)
+                model = model.module.module
+                if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                    self.accelerator.unwrap_model(model).save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
             # Remove temporary sharded checkpoints
             xm.rendezvous("remove_unused_checkpoints")
             os.remove(ckpt_path)

From 9aca73e68f7f0d3c46f99496fd84935df220fa12 Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Tue, 30 Jul 2024 11:34:52 +0800
Subject: [PATCH 3/6] Fix style

---
 src/transformers/trainer.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8e953d81c3bf..ed97b5ad664c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2977,7 +2977,7 @@ def _save_optimizer_and_scheduler(self, output_dir):
                 xm.save(
                     optm,
                     os.path.join(output_dir, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"),
-                    master_only=False
+                    master_only=False,
                 )
             else:
                 xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -3068,8 +3068,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
-                    optimizer_state = torch.load(os.path.join(
-                        checkpoint, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"), map_location="cpu")
+                    optimizer_state = torch.load(
+                        os.path.join(checkpoint, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"), map_location="cpu"
+                    )
                     # We only need `optimizer` when resuming from checkpoint
                     optimizer_state = optimizer_state["optimizer"]
                 else:
@@ -3500,8 +3501,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                 "shard_metadata": model.get_shard_metadata(),
             }
             ckpt_path = os.path.join(
-                output_dir,
-                f"rank{self.args.process_index}_of_{self.args.world_size}_{WEIGHTS_NAME}.pth"
+                output_dir, f"rank{self.args.process_index}_of_{self.args.world_size}_{WEIGHTS_NAME}.pth"
             )
             # All ranks save sharded checkpoint
             xm.save(ckpt, ckpt_path, master_only=False)
@@ -3509,12 +3509,14 @@ def _save_tpu(self, output_dir: Optional[str] = None):
             xm.rendezvous("save_full_checkpoints")
             # Master save full checkpoint
             if self.args.should_save:
-                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
                 from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
                 full_state_dict, _ = consolidate_sharded_model_checkpoints(
                     ckpt_prefix=os.path.join(output_dir, ""),
                     ckpt_suffix=f"rank*_of_*_{WEIGHTS_NAME}.pth",
-                    save_model=False)
+                    save_model=False,
+                )
                 assert isinstance(model, FSDP)
                 model = model.module.module
                 if isinstance(self.accelerator.unwrap_model(model), supported_classes):

From 00aacf94c5c8fc54b63eb3b41d210d8ccf158fce Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Thu, 1 Aug 2024 14:30:22 +0800
Subject: [PATCH 4/6] reserve sharded ckpt and better file naming

---
 src/transformers/trainer.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ed97b5ad664c..3e6833dfc930 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2976,7 +2976,9 @@ def _save_optimizer_and_scheduler(self, output_dir):
                 }
                 xm.save(
                     optm,
-                    os.path.join(output_dir, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"),
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
                     master_only=False,
                 )
             else:
                 xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -3059,7 +3061,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 )
             )
             checkpoint_file_exists = (
-                glob.glob(os.path.join(checkpoint, "rank*_" + OPTIMIZER_NAME))
+                glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
                 if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
                 else checkpoint_file_exists
             )
@@ -3069,7 +3071,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
                 if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
                     optimizer_state = torch.load(
-                        os.path.join(checkpoint, f"rank{self.args.process_index}_{OPTIMIZER_NAME}"), map_location="cpu"
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
                     )
                     # We only need `optimizer` when resuming from checkpoint
                     optimizer_state = optimizer_state["optimizer"]
                 else:
@@ -3501,7 +3506,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                 "shard_metadata": model.get_shard_metadata(),
             }
             ckpt_path = os.path.join(
-                output_dir, f"rank{self.args.process_index}_of_{self.args.world_size}_{WEIGHTS_NAME}.pth"
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
             )
             # All ranks save sharded checkpoint
             xm.save(ckpt, ckpt_path, master_only=False)
@@ -3509,15 +3514,13 @@ def _save_tpu(self, output_dir: Optional[str] = None):
             xm.rendezvous("save_full_checkpoints")
             # Master save full checkpoint
             if self.args.should_save:
-                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                 from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

                 full_state_dict, _ = consolidate_sharded_model_checkpoints(
                     ckpt_prefix=os.path.join(output_dir, ""),
-                    ckpt_suffix=f"rank*_of_*_{WEIGHTS_NAME}.pth",
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
                     save_model=False,
                 )
-                assert isinstance(model, FSDP)
                 model = model.module.module
                 if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                     self.accelerator.unwrap_model(model).save_pretrained(
                         output_dir,
                         state_dict=full_state_dict,
                         save_function=xm.save,
                         safe_serialization=self.args.save_safetensors,
                     )
                 else:
                     logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                     xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
-            # Remove temporary sharded checkpoints
-            xm.rendezvous("remove_unused_checkpoints")
-            os.remove(ckpt_path)
         elif not isinstance(model, supported_classes):
             if isinstance(self.accelerator.unwrap_model(model), supported_classes):
                 self.accelerator.unwrap_model(model).save_pretrained(
                     output_dir,

From 99d568d703c820397c3c134ee8d6145117f5301a Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Fri, 9 Aug 2024 09:52:36 +0800
Subject: [PATCH 5/6] minor fix

Co-authored-by: Zach Mueller
---
 src/transformers/trainer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3e6833dfc930..8e69f4fffca6 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3522,8 +3522,9 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                     save_model=False,
                 )
                 model = model.module.module
-                if isinstance(self.accelerator.unwrap_model(model), supported_classes):
-                    self.accelerator.unwrap_model(model).save_pretrained(
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
                         output_dir,
                         state_dict=full_state_dict,
                         save_function=xm.save,

From dbbff95c32b7aa53ac9a39b644c5ee95e6192287 Mon Sep 17 00:00:00 2001
From: Yitong Huang
Date: Thu, 15 Aug 2024 16:41:55 +0800
Subject: [PATCH 6/6] add is_fsdp_xla_v1_enabled

---
 src/transformers/trainer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8e69f4fffca6..055919ce1184 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -703,6 +703,7 @@ def __init__(
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled

     def _activate_neftune(self, model):
         r"""
@@ -2969,7 +2970,7 @@ def _save_rng_state(self, output_dir):
     def _save_optimizer_and_scheduler(self, output_dir):
         if is_torch_xla_available():
             xm.rendezvous("saving_optimizer_states")
-            if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
+            if self.is_fsdp_xla_v1_enabled:
                 optm = {
                     "optimizer": self.optimizer.state_dict(),
                     "shard_metadata": self.model.get_shard_metadata(),
                 }
@@ -3062,14 +3063,14 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             )
             checkpoint_file_exists = (
                 glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
-                if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
+                if self.is_fsdp_xla_v1_enabled
                 else checkpoint_file_exists
             )
         if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_xla_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
-                if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
+                if self.is_fsdp_xla_v1_enabled:
                     optimizer_state = torch.load(
                         os.path.join(
                             checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
                         ),
                         map_location="cpu",
                     )
@@ -3500,7 +3501,7 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
+        if self.is_fsdp_xla_v1_enabled:
             ckpt = {
                 "model": model.state_dict(),
                 "shard_metadata": model.get_shard_metadata(),
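
Note (not part of the patch series): the consolidation step introduced above relies on
torch_xla's consolidate_sharded_model_checkpoints helper, called exactly as in _save_tpu.
Below is a minimal offline sketch of how the per-rank shards kept on disk after PATCH 4/6
(files named rank<i>-of-<n>-pytorch_model.bin) could be merged into one full state dict,
for example when a run was interrupted before the master rank finished consolidating.
The checkpoint directory path is hypothetical; it assumes WEIGHTS_NAME resolves to
pytorch_model.bin, as it does in transformers.

    # Offline consolidation sketch (assumes torch_xla and transformers are installed,
    # and that ckpt_dir contains the sharded files written by the patched _save_tpu).
    import os

    import torch
    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
    from transformers.utils import WEIGHTS_NAME

    ckpt_dir = "output/checkpoint-500"  # hypothetical output_dir of a TPU FSDP run

    # Mirror the call made on the saving rank in _save_tpu: glob the rank shards and
    # rebuild the unsharded state dict without writing a model file from the helper.
    full_state_dict, _ = consolidate_sharded_model_checkpoints(
        ckpt_prefix=os.path.join(ckpt_dir, ""),
        ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
        save_model=False,
    )

    # Persist the consolidated weights next to the shards, as the patch does.
    torch.save(full_state_dict, os.path.join(ckpt_dir, WEIGHTS_NAME))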