diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index d799cb6fb044..358f3387b812 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1348,6 +1348,8 @@ def _reset_sequence_parallelism_args(self):
 
         # Reset config values. Needed for calling generate.
         self.cfg.sequence_parallel = False
+        self.model_parallel_config.sequence_parallel = False
+        self.transformer_config.sequence_parallel = False
 
         # Reset model parameters.
         for module in self.get_gpt_module_list():
@@ -1362,6 +1364,8 @@ def _restore_sequence_parallelism_args(self):
         """
         # Restore config values.
         self.cfg.sequence_parallel = self.last_sequence_parallel
+        self.model_parallel_config.sequence_parallel = self.last_sequence_parallel
+        self.transformer_config.sequence_parallel = self.last_sequence_parallel
 
         # Restore model parameters.
         for module in self.get_gpt_module_list():
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 3c55ced2cb8e..850d66e286da 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -586,6 +586,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
         # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here
         app_state = AppState()
         self._restore_activation_checkpointing_args()
+        self._restore_sequence_parallelism_args()
         if hasattr(self, "_train_ds"):
             _reconfigure_microbatch_calculator(
                 rank=app_state.global_rank,
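
For readers outside the NeMo codebase, here is a minimal sketch of the pattern these hunks implement: a reset/restore pair that keeps every config view of the sequence_parallel flag in sync, presumably because downstream Megatron-core modules read the flag from model_parallel_config and transformer_config rather than from self.cfg alone. The _Config and Model classes below are hypothetical stand-ins, not NeMo classes; only the attribute names (cfg, model_parallel_config, transformer_config, last_sequence_parallel) and the two method names come from the diff above.

    # Minimal sketch of the reset/restore pattern, under the assumptions stated above.
    from dataclasses import dataclass


    @dataclass
    class _Config:
        sequence_parallel: bool = True


    class Model:
        def __init__(self):
            # Three config views of the same flag; all must stay in sync,
            # since downstream code may consult any one of them.
            self.cfg = _Config()
            self.model_parallel_config = _Config()
            self.transformer_config = _Config()

        def _reset_sequence_parallelism_args(self):
            # Remember the current value, then disable the flag on every view.
            self.last_sequence_parallel = self.cfg.sequence_parallel
            for c in (self.cfg, self.model_parallel_config, self.transformer_config):
                c.sequence_parallel = False

        def _restore_sequence_parallelism_args(self):
            # Put the saved value back on every view.
            for c in (self.cfg, self.model_parallel_config, self.transformer_config):
                c.sequence_parallel = self.last_sequence_parallel


    m = Model()
    m._reset_sequence_parallelism_args()
    assert not m.transformer_config.sequence_parallel
    m._restore_sequence_parallelism_args()
    assert m.transformer_config.sequence_parallel

The SFT hunk follows the same logic on the restore side: inference_epoch_end already undid the activation-checkpointing reset, and the added call makes it undo the sequence-parallelism reset as well, so training after an inference epoch resumes with the original flags.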