From ab9934e16c19ed463c1fb22829db414e8cd66d60 Mon Sep 17 00:00:00 2001
From: eljandoubi
Date: Mon, 28 Oct 2024 00:15:59 +0100
Subject: [PATCH 1/6] Remove FSDP wrapping from sub-models.

---
 src/transformers/trainer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d41b7181be63..8fcdc4bd984d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2277,12 +2277,13 @@ def _inner_training_loop(
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False
 
-        # configure fsdp plugin for qlora if any
-        if use_accelerator_prepare:
-            self._fsdp_qlora_plugin_updates()
+        if use_accelerator_prepare and self.is_fsdp_enabled:
+            # Remove FSDP wrapping from sub-models.
+            self.model = extract_model_from_parallel(self.model, recursive=True)
 
         if delay_optimizer_creation:
             if use_accelerator_prepare:
+                self._fsdp_qlora_plugin_updates()
                 self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 

From d518f76463b984707bb23e7041976caa08e5f9d6 Mon Sep 17 00:00:00 2001
From: AbdelKarim ELJANDOUBI <78537694+eljandoubi@users.noreply.github.com>
Date: Mon, 28 Oct 2024 14:30:43 +0100
Subject: [PATCH 2/6] solve conflict trainer.py

---
 src/transformers/trainer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8fcdc4bd984d..0c9091c36591 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2276,14 +2276,16 @@ def _inner_training_loop(
         # this is for unhandled cases such as
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False
-
+        
         if use_accelerator_prepare and self.is_fsdp_enabled:
+            #In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = extract_model_from_parallel(self.model, recursive=True)
+            # configure fsdp plugin for qlora if any
+            self._fsdp_qlora_plugin_updates()
 
         if delay_optimizer_creation:
             if use_accelerator_prepare:
-                self._fsdp_qlora_plugin_updates()
                 self.model = self.accelerator.prepare(self.model)
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 

From e76dfd03cfe1225206673ab68ca7e9358ff3d0b4 Mon Sep 17 00:00:00 2001
From: eljandoubi
Date: Tue, 29 Oct 2024 15:47:15 +0100
Subject: [PATCH 3/6] make fixup

---
 src/transformers/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0c9091c36591..38b301e3544b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2276,9 +2276,9 @@ def _inner_training_loop(
         # this is for unhandled cases such as
         # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False
-        
+
         if use_accelerator_prepare and self.is_fsdp_enabled:
-            #In case of auto_find_batch_size=True
+            # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
             self.model = extract_model_from_parallel(self.model, recursive=True)
             # configure fsdp plugin for qlora if any
 

From a6dcb92a44c52b92a2a766768687941a0484b699 Mon Sep 17 00:00:00 2001
From: eljandoubi
Date: Thu, 31 Oct 2024 00:01:45 +0100
Subject: [PATCH 4/6] add unit test for fsdp_auto_wrap_policy when using
 auto_find_batch_size

---
 tests/trainer/test_trainer_fsdp.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/trainer/test_trainer_fsdp.py b/tests/trainer/test_trainer_fsdp.py
index 4bcf5de04520..eca6a30664f0 100644
--- a/tests/trainer/test_trainer_fsdp.py
+++ b/tests/trainer/test_trainer_fsdp.py
@@ -117,6 +117,33 @@ def test_trainer(self):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
+class TestFSDPTrainerWrap(TestCasePlus):
+    @require_accelerate
+    @require_torch_multi_gpu
+    @require_fsdp
+    def test_trainer(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        cmd = [
+            "accelerate",
+            "launch",
+            "--use_fsdp",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
+            "--num_processes",
+            f"{torch.cuda.device_count()}",
+            "--fsdp_transformer_layer_cls_to_wrap",
+            "GPT2Block",
+            f"{self.test_file_dir}/test_trainer_fsdp.py",
+            "--output_dir",
+            f"{output_dir}",
+            "--report_to",
+            "none",
+            "--auto_find_batch_size",
+            "True",
+        ]
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call
+
 
 if __name__ == "__main__":
     parser = HfArgumentParser((Seq2SeqTrainingArguments,))
 

From 7ebb2c6134df7f52c8366b95a2f0de9086c714fc Mon Sep 17 00:00:00 2001
From: AbdelKarim ELJANDOUBI <78537694+eljandoubi@users.noreply.github.com>
Date: Mon, 4 Nov 2024 14:47:28 +0100
Subject: [PATCH 5/6] put back extract_model_from_parallel

---
 src/transformers/trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 38b301e3544b..8fb3e13eb572 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -233,6 +233,7 @@
     from accelerate.utils import (
         DistributedDataParallelKwargs,
         DistributedType,
+        extract_model_from_parallel,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
 

From 0df20d68955962d139f593df9f0a819dd2c49d63 Mon Sep 17 00:00:00 2001
From: AbdelKarim ELJANDOUBI <78537694+eljandoubi@users.noreply.github.com>
Date: Mon, 4 Nov 2024 16:26:02 +0100
Subject: [PATCH 6/6] use transformers unwrap_model

---
 src/transformers/trainer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8fb3e13eb572..f3935c0ffd2f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -66,7 +66,7 @@
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
-from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
@@ -233,7 +233,6 @@
     from accelerate.utils import (
         DistributedDataParallelKwargs,
         DistributedType,
-        extract_model_from_parallel,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
@@ -2281,7 +2280,7 @@ def _inner_training_loop(
         if use_accelerator_prepare and self.is_fsdp_enabled:
             # In case of auto_find_batch_size=True
             # Remove FSDP wrapping from sub-models.
-            self.model = extract_model_from_parallel(self.model, recursive=True)
+            self.model = unwrap_model(self.model, recursive=True)
             # configure fsdp plugin for qlora if any
             self._fsdp_qlora_plugin_updates()
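
Note (illustration only, not part of the patch series): the sketch below shows the failure mode these commits address. With auto_find_batch_size=True, Accelerate's find_executable_batch_size retries the inner training loop after an out-of-memory error; on the retry, self.model is already FSDP-wrapped by the previous accelerator.prepare call, so preparing it again would nest wrappers unless the model is unwrapped first. FakeTrainer, its prepare/unwrap methods, and inner_training_loop are hypothetical stand-ins invented for this example; only find_executable_batch_size is the real Accelerate utility.

# A toy reproduction of the double-wrapping problem, independent of the real Trainer.
from accelerate.utils import find_executable_batch_size


class FakeTrainer:
    def __init__(self):
        self.model = "bare_model"
        self.attempts = 0

    def prepare(self):
        # Stand-in for accelerator.prepare() applying the FSDP auto-wrap policy.
        self.model = f"FSDP({self.model})"

    def unwrap(self):
        # Stand-in for unwrap_model(self.model, recursive=True) in the final patch.
        while self.model.startswith("FSDP("):
            self.model = self.model[len("FSDP("):-1]


trainer = FakeTrainer()


@find_executable_batch_size(starting_batch_size=64)
def inner_training_loop(batch_size):
    trainer.attempts += 1
    trainer.unwrap()   # without this line the retry would produce FSDP(FSDP(bare_model))
    trainer.prepare()
    if trainer.attempts == 1:
        # Simulate the OOM that makes auto_find_batch_size halve the batch size and retry.
        raise RuntimeError("CUDA out of memory.")
    print(f"batch_size={batch_size}, model={trainer.model}")


inner_training_loop()  # prints: batch_size=32, model=FSDP(bare_model)

In the actual diff the unwrap only runs when both use_accelerator_prepare and self.is_fsdp_enabled are true, so non-FSDP runs are left untouched.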