diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 5554220653227..a7b5b9c1864a8 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -94,8 +94,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         else:
             base_module = self.model
 
+        # Set the profile start and end steps in the unit of global batch
+        if hasattr(self, '_nsys_profile_enabled'):
+            self._nsys_profile_start_step = self.cfg.nsys_profile.get('start_step', 0)
+            self._nsys_profile_end_step = self.cfg.nsys_profile.get('end_step', 0)
+
         self._reset_activation_checkpointing_args()
-        self._reset_sequence_parallelism_args()
         self.virtual_tokens = 0
 
     def setup_metric(self, data_cfg):
@@ -593,7 +597,6 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
        # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here
        app_state = AppState()
        self._restore_activation_checkpointing_args()
-        self._restore_sequence_parallelism_args()
        if hasattr(self, "_train_ds"):
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
@@ -816,7 +819,6 @@ def setup_eval_dataloader(self, datasets, data_cfg):
 
     def on_validation_epoch_start(self):
         self._reset_activation_checkpointing_args()
-        self._reset_sequence_parallelism_args()
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
@@ -829,7 +831,6 @@ def on_validation_epoch_start(self):
 
     def on_test_epoch_start(self):
         self._reset_activation_checkpointing_args()
-        self._reset_sequence_parallelism_args()
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py
index 6c51988e67282..1a13d22785202 100644
--- a/nemo/collections/nlp/modules/common/text_generation_utils.py
+++ b/nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -730,9 +730,6 @@ def sample_sequence_batch(
         micro_batch_size=micro_batch_size,
         data_parallel_size=1,
     )
-    assert (
-        model.cfg.get('sequence_parallel', False) == False
-    ), 'sequence_parallel should be False during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'
     assert (
         model.cfg.get('activations_checkpoint_granularity', None) is None
    ), 'activations_checkpoint_granularity should be None during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index 4c7efffcb1174..754d18eb8267c 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -190,6 +190,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         # Setup nsys profiling if it has been enabled in the model config
         self._setup_nsys_profiling()
 
+        # A flag for the profile generation
+        self._profile_complete = False
+
     def __init_subclass__(cls) -> None:
         cls._save_restore_connector = SaveRestoreConnector()
 
@@ -1720,7 +1723,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
         # nsys profiling
         if self.device.type == 'cuda':
             if hasattr(self, '_nsys_profile_enabled'):
-                if self._nsys_profile_enabled:
+                if self._nsys_profile_enabled and not self._profile_complete:
                     if batch_idx == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
                         logging.info("====== Start nsys profiling ======")
                         torch.cuda.cudart().cudaProfilerStart()
@@ -1757,10 +1760,11 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
 
         if self.device.type == 'cuda':
             if hasattr(self, '_nsys_profile_enabled'):
-                if self._nsys_profile_enabled:
+                if self._nsys_profile_enabled and not self._profile_complete:
                     if batch_idx == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
                         logging.info("====== End nsys profiling ======")
                         torch.cuda.cudart().cudaProfilerStop()
+                        self._profile_complete = True
 
     def _cleanup_on_execution_end(self):
         """
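
For reference, a minimal sketch of the model config these nsys hooks read. The exact key layout is an assumption inferred from the attribute reads in the diff (start_step/end_step via cfg.nsys_profile, plus the enabled flag and rank list behind _nsys_profile_enabled and _nsys_profile_ranks), not a verbatim copy of a shipped config:

    # Hypothetical config snippet consumed by _setup_nsys_profiling() and the
    # on_train_batch_start/on_train_batch_end hooks patched above.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "nsys_profile": {
                "enabled": True,   # surfaces as self._nsys_profile_enabled
                "start_step": 10,  # global batch at which cudaProfilerStart() fires
                "end_step": 12,    # global batch at which cudaProfilerStop() fires
                "ranks": [0],      # surfaces as self._nsys_profile_ranks
            }
        }
    )

Since cudaProfilerStart()/cudaProfilerStop() only delimit the capture range, the training script still has to run under Nsight Systems with API-triggered capture, e.g. nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop python train.py (flag names from the nsys CLI, not from this repo). The new _profile_complete flag makes the capture one-shot: once cudaProfilerStop() has fired, a later batch_idx that matches start_step again will not re-enter profiling.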