From 16ff1381ba92fa615441ca4672415de0dbf0209e Mon Sep 17 00:00:00 2001
From: Sangkug Lym <slym@nvidia.com>
Date: Mon, 5 Jun 2023 08:36:29 -0700
Subject: [PATCH] Add ub communicator initialization to validation step (#6807)

---
 .../language_modeling/megatron_gpt_model.py | 32 ++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index fd1382e668cf..f2147ff563e4 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -499,12 +499,40 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
 
         return loss_mean
 
+    def initialize_ub_func(self):
+        input_shape = [
+            self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
+            self.cfg.get('hidden_size'),
+        ]
+        ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
+        if ub_cfg_file_name is not None:
+            try:
+                import yaml
+
+                with open(ub_cfg_file_name, 'r') as ub_cfg_file:
+                    ub_cfgs = yaml.safe_load(ub_cfg_file)
+            except (ImportError, TypeError):
+                print("Failed to read the ub_tp_comm_overlap config file.")
+                ub_cfgs = None
+        else:
+            ub_cfgs = None
+        te_module.initialize_ub(
+            shape=input_shape,
+            tp_size=self.cfg.get('tensor_model_parallel_size'),
+            use_fp8=self.cfg.get('fp8'),
+            ub_cfgs=ub_cfgs,
+        )
+        self.initialize_ub = False
+
     def training_step(self, dataloader_iter, batch_idx):
         """
             We pass the dataloader iterator function to the micro-batch scheduler.
             The input batch to each micro-batch is fetched using the dataloader function
             in the micro-batch fwd function.
         """
+        # Initialize userbuffer communicators.
+        if self.initialize_ub:
+            self.initialize_ub_func()
 
         # we zero grads here because we also call backward in the megatron-core fwd/bwd functions
         self._optimizer.zero_grad()
@@ -829,6 +857,10 @@ def validation_step(self, dataloader_iter, batch_idx):
             from the dataloader to produce a list of microbatches.
             The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
         """
+        # Initialize userbuffer communicators.
+        if self.initialize_ub:
+            self.initialize_ub_func()
+
         if isinstance(self.model, list):
             for model_module in self.model:
                 model_module.eval()
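
Notes:

The guard added to validation_step mirrors the one in training_step because
validation can run before any training step (for example, PyTorch Lightning's
sanity-check validation pass), so the userbuffer communicators must be created
lazily by whichever step executes first. Below is a minimal standalone sketch
of that lazy one-shot initialization pattern; the class name, config key, and
setup body are simplified placeholders, not taken from this patch:

    class Model:
        def __init__(self, cfg):
            self.cfg = cfg
            # One-shot flag, assumed here to come from a config switch
            # ('ub_tp_comm_overlap' is a placeholder key).
            self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)

        def initialize_ub_func(self):
            # Collective, expensive setup; must run exactly once before the
            # first forward/backward pass that overlaps TP communication.
            self.initialize_ub = False

        def training_step(self, batch):
            if self.initialize_ub:
                self.initialize_ub_func()

        def validation_step(self, batch):
            # Same guard: validation may be the first step executed.
            if self.initialize_ub:
                self.initialize_ub_func()

Because the setup is collective across the tensor-parallel group, every rank
must take the same branch; deriving the flag from the shared config rather
than from per-rank state keeps the guard consistent on all ranks.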
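
For sizing, initialize_ub receives the shape of the flattened GEMM input:
the sequence and micro-batch dimensions are collapsed into a single leading
dimension, with hidden_size trailing. A quick worked example with illustrative
numbers (these values are not from the patch):

    # Illustrative config values, not taken from the patch:
    encoder_seq_length = 2048
    micro_batch_size = 1
    hidden_size = 4096

    input_shape = [encoder_seq_length * micro_batch_size, hidden_size]
    # -> [2048, 4096], i.e. one row per token in the micro-batch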
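
The optional ub_tp_comm_overlap_cfg file is parsed with yaml.safe_load and
passed to Transformer Engine unmodified, so its schema is owned by TE rather
than by this patch. A hedged sketch of that loading step in isolation; the
file path is illustrative, and the per-GEMM section names in the comment
mirror NeMo's shipped tp-overlap configs but should be verified against the
TE/NeMo versions in use:

    import yaml

    # Path is illustrative; in the patch it comes from
    # cfg['ub_tp_comm_overlap_cfg'].
    with open('ub_overlap.yaml', 'r') as f:
        ub_cfgs = yaml.safe_load(f)

    # An illustrative parsed result (keys are examples only):
    # {'qkv_fprop': {'method': 'ring_exchange'},
    #  'fc1_fprop': {'method': 'pipeline', 'num_sm': 4}}

If the file is absent or unreadable, the method falls back to ub_cfgs = None,
which lets Transformer Engine use its built-in defaults.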