From 335b876ebe3ab10fe5c40d4b0857a78f4c820adb Mon Sep 17 00:00:00 2001
From: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com>
Date: Mon, 21 Aug 2023 09:25:09 -0700
Subject: [PATCH] Fix restore sequence parallel (#7273)

* Fix restore

Signed-off-by: Cheng-Ping Hsieh

* reset and restore transformer config sequence parallel

Signed-off-by: Jason Wang

* modify model parallel config as well

Signed-off-by: Jason Wang

---------

Signed-off-by: Cheng-Ping Hsieh
Signed-off-by: Jason Wang
Co-authored-by: Jason Wang
---
 .../nlp/models/language_modeling/megatron_gpt_model.py     | 4 ++++
 .../nlp/models/language_modeling/megatron_gpt_sft_model.py | 1 +
 2 files changed, 5 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index d799cb6fb044..358f3387b812 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1348,6 +1348,8 @@ def _reset_sequence_parallelism_args(self):
 
         # Reset config values. Needed for calling generate.
         self.cfg.sequence_parallel = False
+        self.model_parallel_config.sequence_parallel = False
+        self.transformer_config.sequence_parallel = False
 
         # Reset model parameters.
         for module in self.get_gpt_module_list():
@@ -1362,6 +1364,8 @@ def _restore_sequence_parallelism_args(self):
         """
         # Restore config values.
         self.cfg.sequence_parallel = self.last_sequence_parallel
+        self.model_parallel_config.sequence_parallel = self.last_sequence_parallel
+        self.transformer_config.sequence_parallel = self.last_sequence_parallel
 
         # Restore model parameters.
         for module in self.get_gpt_module_list():
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 3c55ced2cb8e..850d66e286da 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -586,6 +586,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
         # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here
         app_state = AppState()
         self._restore_activation_checkpointing_args()
+        self._restore_sequence_parallelism_args()
         if hasattr(self, "_train_ds"):
             _reconfigure_microbatch_calculator(
                 rank=app_state.global_rank,
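
Note: the sketch below illustrates the save/disable/restore pattern this patch keeps consistent across the three config objects (cfg, model_parallel_config, transformer_config). It is a minimal, self-contained toy, not the NeMo implementation; the class `ToyModel` and its attributes are illustrative assumptions.

```python
# Minimal sketch of the reset/restore pattern (names are hypothetical, not NeMo code).
from dataclasses import dataclass


@dataclass
class _Cfg:
    sequence_parallel: bool = True


class ToyModel:
    def __init__(self):
        # Three config objects that must stay in sync, mirroring
        # cfg / model_parallel_config / transformer_config in the patch.
        self.cfg = _Cfg()
        self.model_parallel_config = _Cfg()
        self.transformer_config = _Cfg()
        self.last_sequence_parallel = None

    def _reset_sequence_parallelism_args(self):
        # Save the original flag, then disable it in every config object.
        self.last_sequence_parallel = self.cfg.sequence_parallel
        for c in (self.cfg, self.model_parallel_config, self.transformer_config):
            c.sequence_parallel = False

    def _restore_sequence_parallelism_args(self):
        # Put the saved flag back in every config object.
        for c in (self.cfg, self.model_parallel_config, self.transformer_config):
            c.sequence_parallel = self.last_sequence_parallel


model = ToyModel()
model._reset_sequence_parallelism_args()    # before generate()/inference
assert not model.transformer_config.sequence_parallel
model._restore_sequence_parallelism_args()  # after inference ends
assert model.transformer_config.sequence_parallel
```

The SFT change follows the same idea: inference_epoch_end now calls the restore helper so the sequence-parallel setting does not stay disabled after evaluation, matching the existing restore of activation checkpointing.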