diff --git a/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml new file mode 100644 index 000000000000..d63255d50ed3 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml @@ -0,0 +1,43 @@ +# The main purpose of this file is documentation; it should not be used directly +enc_output_name: z # name of key in hidden transforms output to pass to decoder (e.g., z for VAE/MIM) +tokens_loss_weight: 1.0 # weight of tokens loss (if not specified defaults to 1.0) +# the lists below are useful for adding multiple transforms and losses according to order +# if order is not important, you can use a single dictionary in the list with multiple keys +transform: # a list of dictionaries of transforms (or a joint dictionary) to apply to hiddens (list enforces order) + # - <name>: # name of transform + # cls_name: <class_path> # class path name + # <param>: <value> # transform parameters + # ... + - q_z_given_x: # Gaussian posterior with reparameterization + cls_name: cond_gaussian # class path name + hidden_size: 512 # hidden size of the encoder + min_logvar: -6.0 # minimum log variance + - logP_cls: + cls_name: guided_cls + input_name: hiddens + attr_name: logP + - QED_cls: + cls_name: guided_cls + input_name: hiddens + attr_name: QED +loss: # a list of dictionaries of loss terms (or a joint dictionary) to add to reconstruction loss (list enforces order) + # - <name>: # name of loss + # cls_name: <class_path> # class path name + # <param>: <value> # loss parameters + # ... + # below is example where order of losses does not matter so a single item in list is enough + - mim: # A-MIM example + cls_name: a_mim + loss_weight: 1.0 # weight of the MIM latent loss + vae: # VAE example + cls_name: vae + min_kl_value: null # minimum KL value if a float is provided + loss_weight: 1e-2 # weight of KL term in loss + logP_cls: + cls_name: guided_cls_loss + input_name: logP + loss_weight: 1.0 + QED_cls: + cls_name: guided_cls_loss + input_name: QED + loss_weight: 1.0 diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index ddd46e681f94..82a3cbf36b64 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -249,7 +249,7 @@ def loss_func(output_tensor): lm_loss = loss_dict['lm loss'] loss = lm_loss reduced_loss = average_losses_across_data_parallel_group([loss, lm_loss]) - return loss, {'avg': reduced_loss} + return loss, {'loss': reduced_loss} return output_tensor, loss_func @@ -334,7 +334,7 @@ def training_step(self, dataloader_iter, batch_idx): ) if losses_reduced_per_micro_batch: - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.vstack(loss_tensors_list) loss_mean = loss_tensor.mean(axis=0) else: @@ -447,7 +447,7 @@ def validation_step(self, dataloader_iter, batch_idx): ) if losses_reduced_per_micro_batch: - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.vstack(loss_tensors_list) loss_mean = loss_tensor.mean(axis=0) else: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index 9fce0d52c4a1..d854505d1f74 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -276,47 +276,27 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config): def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ - # Get seq length of batch batch = next(dataloader_iter) if isinstance(batch, dict): # convert to list if not already converted. batch = self._process_batch(batch) - _, seq_length = batch[0].shape - _, dec_seq_length = batch[1].shape - tensor_shape = [seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] - data_iter = get_iterator_k_split(batch, get_num_microbatches()) + # Get seq length of batch + encoder_seq_length = batch[0].size(1) + decoder_seq_length = batch[1].size(1) - fwd_bwd_function = get_forward_backward_func() + tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] + data_iter = get_iterator_k_split(batch, get_num_microbatches()) - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), + return self._execute_fwd_bwd_function( data_iterator=data_iter, - model=[self.enc_dec_model], - num_microbatches=get_num_microbatches(), forward_only=forward_only, tensor_shape=tensor_shape, - decoder_seq_length=dec_seq_length, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None, - sequence_parallel=self.cfg.get('sequence_parallel', False), - enable_autocast=self.enable_autocast, + decoder_seq_length=decoder_seq_length, ) - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() - else: - # we're not on the last pipeline stage so no losses - loss_mean = torch.tensor(0.0).cuda() - - return loss_mean - def inference_step(self, dataloader_iter, batch_idx: int, mode: str, dataloader_idx=0): # Add try except since dataloader_iter in PTL 2.0 doesnt catch the end of the iterator try: @@ -366,12 +346,16 @@ def inference_step(self, dataloader_iter, batch_idx: int, mode: str, dataloader_ _ = metric(pred, label) outputs = { - 'loss': loss, 'preds': preds_text, 'labels': labels_text, 'categories': categories, 'inputs': input_text, } + + if isinstance(loss, dict): + outputs.update(loss) + else: + outputs['loss'] = loss if mode == 'validation': if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: self.validation_step_outputs[dataloader_idx].append(outputs) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index f8f8fe808612..ff4da0f624ed 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ 
-240,7 +240,6 @@ def _populate_encoder_decoder_configs_for_backward_compatibility(self, cfg): ) # For models before separate encoder/decoder configs, tokens_head_bias was always True. def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder): - # TODO: create get_encoder_decoder_model()here for different losses (e..g, nll, vae, mim) if not hasattr(self.cfg, 'encoder') or not hasattr(self.cfg, 'decoder'): logging.warning( 'Could not find encoder or decoder in config. This is probably because of restoring an old checkpoint. Copying shared model configs to encoder and decoder configs.' @@ -282,6 +281,7 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode share_token_embeddings=self.cfg.get('share_token_embeddings', True), share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), tokens_head_bias=self.cfg.get('tokens_head_bias', True), + hiddens_cfg=self.cfg.get('hiddens', None), ) return model @@ -313,42 +313,54 @@ def forward( return output_tensor - def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + def _execute_fwd_bwd_function(self, data_iterator, forward_only, tensor_shape, decoder_seq_length): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + An auxiliary function that executes the fwd_bwd_step function and parse the returned values. """ - # Get seq length of batch - tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size] - fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(), - data_iterator=dataloader_iter, + data_iterator=data_iterator, model=[self.enc_dec_model], num_microbatches=get_num_microbatches(), forward_only=forward_only, tensor_shape=tensor_shape, - decoder_seq_length=self.max_decoder_seq_length, + decoder_seq_length=decoder_seq_length, dtype=self.autocast_dtype, grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None, + sequence_parallel=self.cfg.get('sequence_parallel', False), enable_autocast=self.enable_autocast, ) # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() + mean_loss_dict = {} + for k in losses_reduced_per_micro_batch[0].keys(): + # average loss across micro batches + mean_loss_dict[k] = torch.stack( + [loss_reduced[k] for loss_reduced in losses_reduced_per_micro_batch] + ).mean() else: - if forward_only: - loss_mean = [] - else: - loss_mean = torch.tensor(0.0).cuda() + loss_mean = torch.tensor(0.0).cuda() + mean_loss_dict = {"loss": loss_mean} - return loss_mean + return mean_loss_dict + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + """ + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. 
+ """ + # Get seq length of batch + tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size] + + return self._execute_fwd_bwd_function( + data_iterator=dataloader_iter, + forward_only=forward_only, + tensor_shape=tensor_shape, + decoder_seq_length=self.max_decoder_seq_length, + ) def training_step(self, dataloader_iter, batch_idx): """ @@ -362,7 +374,7 @@ def training_step(self, dataloader_iter, batch_idx): # we zero grads here because we also call backward in the megatron fwd/bwd functions self._optimizer.zero_grad() - loss_mean = self.fwd_bwd_step(dataloader_iter, batch_idx, False) + loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, False) if self.with_distributed_adam: # synchronize asynchronous grad reductions @@ -386,14 +398,16 @@ def training_step(self, dataloader_iter, batch_idx): ## logging # we can only log on one rank if it is rank zero so we broadcast from last rank # we can avoid this broadcast by updating the PTL log function to accept specific ranks - torch.distributed.broadcast(loss_mean, get_last_rank()) + for k, v in loss_dict.items(): + torch.distributed.broadcast(v, get_last_rank()) + n = f'reduced_train_{k}' + self.log(n, v, prog_bar=n.endswith("_loss"), rank_zero_only=True, batch_size=1) if self.cfg.precision == 16: loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] self.log('lr', lr, rank_zero_only=True, batch_size=1) self.log( @@ -407,7 +421,7 @@ def training_step(self, dataloader_iter, batch_idx): rank_zero_only=True, batch_size=1, ) - return loss_mean + return loss_dict @property def max_decoder_seq_length(self) -> int: @@ -556,16 +570,26 @@ def _process_batch(self, global_batch: Dict[str, torch.Tensor]) -> List[torch.Te global_batch["labels"], global_batch["enc_mask"], global_batch["dec_mask"], + global_batch.get('data', None), ] def get_forward_output_and_loss_func(self): def fwd_output_and_loss_func(dataloader_iter, model): batch = next(dataloader_iter) + # convert to list if not already converted. if isinstance(batch, dict): # convert to list if not already converted. 
batch = self._process_batch(batch) - batch = [x.cuda(non_blocking=True) for x in batch] - encoder_input_ids, decoder_input_ids, loss_mask, lm_labels, encoder_attn_mask, decoder_attn_mask = batch + batch = [x.cuda(non_blocking=True) if torch.is_tensor(x) else x for x in batch] + ( + encoder_input_ids, + decoder_input_ids, + loss_mask, + lm_labels, + encoder_attn_mask, + decoder_attn_mask, + batch_data, + ) = batch output = model( encoder_input_ids, # enc_input_ids @@ -574,12 +598,32 @@ def fwd_output_and_loss_func(dataloader_iter, model): decoder_attn_mask, # dec_attn_mask None, # token_type_ids lm_labels, # labels + batch_data, # batch_data ) def loss_func(output_tensor): - loss = self.loss_func(loss_mask, output_tensor) - reduced_loss = average_losses_across_data_parallel_group([loss]) - return loss, {'avg': reduced_loss} + if isinstance(output_tensor, dict): + # handle loss of hidden transformations + loss_dict = output_tensor + output_tensor = loss_dict.pop("output") + # compute reconstruction (tokens) only loss from per-token reconstruction loss + tokens_loss = self.loss_func(loss_mask, output_tensor) + loss_dict["tokens_loss"] = tokens_loss + tokens_loss_weight = loss_dict.get("tokens_loss_weight", 1.0) + # compute total loss + loss = loss_dict["loss"] = loss_dict["hiddens_loss"] + tokens_loss_weight * tokens_loss + # average losses across data parallel group + loss_dict = { + k: average_losses_across_data_parallel_group([v.mean()]) for k, v in loss_dict.items() + } + else: + # compute reconstruction (tokens) only loss from per-token reconstruction loss + loss = self.loss_func(loss_mask, output_tensor) + # average losses across data parallel group + reduced_loss = average_losses_across_data_parallel_group([loss]) + loss_dict = {'loss': reduced_loss} + + return loss, loss_dict return output, loss_func @@ -645,75 +689,104 @@ def _get_forward_output_only_func(self, arg_names, output_name, **kwargs): def fwd_output_only_func(dataloader_iter, model): batch = next(dataloader_iter) - batch = [x.cuda(non_blocking=True) for x in batch] + batch = [x.cuda(non_blocking=True) if torch.is_tensor(x) else x for x in batch] # map batch and shared args into forward args args = self._build_forward_args_from_kwargs(args_name=arg_names, args=batch, **kwargs) output = model(*args).contiguous() def id_func(output_tensor): + if isinstance(output_tensor, dict): + # handle loss of hidden transformations ("output" is the default output) + output_tensor = output_tensor["output"] + return output_tensor, {output_name: output_tensor} return output, id_func return fwd_output_only_func - def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + ########## + + def _test_validation_step(self, step_outputs, dataloader_iter, batch_idx, dataloader_idx=0): """ - return_values - if given, returns a dictionary with given keys and corresponding values + Shared code for validation and test step """ # Prefetch the dataloader_iter before fwd_bwd func to avoid PP rank 2 from waiting indefinitely with PP rank 1 reaches the end of dataloader_iter dataloader_iter, done = self._prefetch(dataloader_iter) if done: return - prefix = "test" if self.trainer.testing else "val" - loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True) - if prefix == 'val': - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(loss) - else: - self.validation_step_outputs.append(loss) + + loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, 
True) + step_outputs.append(loss_dict) + + return loss_dict + + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + """ + return_values - if given, returns a dictionary with given keys and corresponding values + """ + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + step_outputs = self.validation_step_outputs[dataloader_idx] else: - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(loss) - else: - self.test_step_outputs.append(loss) + step_outputs = self.validation_step_outputs - return loss + return self._test_validation_step( + step_outputs=step_outputs, + dataloader_iter=dataloader_iter, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) - def on_validation_epoch_end(self): + def test_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + step_outputs = self.test_step_outputs[dataloader_idx] + else: + step_outputs = self.test_step_outputs + + return self._test_validation_step( + step_outputs=step_outputs, + dataloader_iter=dataloader_iter, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) + + def _test_validation_epoch_end(self, step_outputs, prefix): + """ + Shared logging for validation and test + """ # NOTE: we need to make sure outputs is not empty (this is a workaround for a bug in pytorch lightning (?)) - if not self.validation_step_outputs: - logging.warning("validation_epoch_end: outputs is empty") + if not step_outputs: + logging.warning(f"{prefix} epoch end: outputs is empty") return None - if parallel_state.is_pipeline_last_stage(): - # only the last pipeline parallel stages return loss - averaged_loss = torch.stack(self.validation_step_outputs).mean() + + # only the last pipeline parallel stages return loss + if parallel_state.is_pipeline_last_stage() and len(step_outputs): + averaged_loss = {k: torch.stack([x[k] for x in step_outputs]).mean() for k in step_outputs[0].keys()} else: - averaged_loss = torch.tensor(0.0).cuda() + # if we are here we assume that only loss is available and hidden transforms are disabled (since not supported in pipeline parallel) + averaged_loss = {'loss': torch.tensor(0.0).cuda()} # we can only log on one rank if it is rank zero so we broadcast from last rank - torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) - self.validation_step_outputs.clear() # free memory + for k, v in averaged_loss.items(): + torch.distributed.broadcast(v, get_last_rank()) + averaged_loss[k] = v + n = f'{prefix}_{k}' + # log only '*_loss' values in progress bar + self.log(n, v, prog_bar=(n.endswith("_loss")), rank_zero_only=True, batch_size=1) + + # free memory + step_outputs.clear() + return averaged_loss - def test_step(self, dataloader_iter, batch_idx): - return self.validation_step(dataloader_iter, batch_idx) + def on_validation_epoch_end(self): + # FIXME: do we need this?
'global_step' is logged in training_step + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) + return self._test_validation_epoch_end(step_outputs=self.validation_step_outputs, prefix="val",) def on_test_epoch_end(self): - if parallel_state.is_pipeline_last_stage(): - # only the last pipeline parallel stages return loss - averaged_loss = torch.stack(self.test_step_outputs).mean() - else: - averaged_loss = torch.tensor(0.0).cuda() - - # we can only log on one rank if it is rank zero so we broadcast from last rank - torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('test_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) - self.test_step_outputs.clear() # free memory - return averaged_loss + return self._test_validation_epoch_end(step_outputs=self.test_step_outputs, prefix="test",) def loss_func(self, loss_mask, tokens_loss): """ @@ -937,11 +1010,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] logging.info(f"response: {response}") return response - def encode(self, tokens_enc, enc_mask, encoder_input=None, reconfigure_microbatch=True): + def encode(self, tokens_enc, enc_mask, encoder_input=None, batch_data=None, reconfigure_microbatch=True): """ tokens_enc - encoder input tokens enc_mask - corresponding mask encoder_input - encoder input (bypass tokens), if given tokens_enc can be None. + batch_data - passed directly to all hidden transformations and losses. + Can be used to pass additional data like class label. + Format is not defined and should match the expected format of the used hiddens modules. """ # Check whether the DDP is initialized. This is needed when running inference outside of training loop. if parallel_state.is_unitialized(): @@ -987,8 +1063,8 @@ def dummy(): # build input arguments description if tokens_enc is not None: - batch_for_pipeline = [tokens_enc, enc_mask] - arg_names = ['enc_input_ids', 'enc_attn_mask'] + batch_for_pipeline = [tokens_enc, enc_mask, batch_data] + arg_names = ['enc_input_ids', 'enc_attn_mask', 'batch_data'] else: if encoder_input is None: raise ValueError("At least one of tokens_enc and encoder_input must be provided with not None value") @@ -1060,6 +1136,7 @@ def decode( ignore_ids=[], bos_id=None, # If bos=None, will use tokenizer.bos_id unless explicitly set to something else. predicted_tokens_dec=None, + batch_data=None, sampling_method: str = "greedy-search", sampling_kwargs: dict = {}, ): @@ -1168,8 +1245,8 @@ def dummy(): dec_mask = predicted_tokens_dec != tokenizer.pad_id dec_mask[:, 0] = 1 # Make sure you never mask the first token even if it is . 
- batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask] - arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask'] + batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask, batch_data] + arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask', 'batch_data'] forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits") fwd_bwd_func = get_forward_backward_func() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index 2f667d815827..c7e63e1c5a59 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -203,7 +203,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.concat(loss_tensors_list) loss_mean = loss_tensor.mean() else: @@ -213,6 +213,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): return loss_mean def get_forward_output_and_loss_func(self): + # FIXME: consolidate this method into MegatronLMEncoderDecoderModel (or have a common base class) def fwd_output_and_loss_func(dataloader_iter, model): batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) for x in batch] @@ -226,7 +227,7 @@ def fwd_output_and_loss_func(dataloader_iter, model): def loss_func(output_tensor): loss = self.frozen_model.loss_func(loss_mask, output_tensor) reduced_loss = average_losses_across_data_parallel_group([loss]) - return loss, {'avg': reduced_loss} + return loss, {'loss': reduced_loss} return output_tensor, loss_func diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 0a233866cdff..3cd15100111e 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -303,34 +303,13 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] data_iter = get_iterator_k_split(batch, get_num_microbatches()) - fwd_bwd_function = get_forward_backward_func() - - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), + return self._execute_fwd_bwd_function( data_iterator=data_iter, - model=[self.enc_dec_model], - num_microbatches=get_num_microbatches(), forward_only=forward_only, tensor_shape=tensor_shape, decoder_seq_length=decoder_seq_length, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None, - sequence_parallel=self.cfg.get('sequence_parallel', False), - enable_autocast=self.enable_autocast, ) - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in 
losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() - else: - # we're not on the last pipeline stage so no losses - loss_mean = torch.tensor(0.0).cuda() - - return loss_mean - def eval_step(self, dataloader_iter, batch_idx, dataloader_idx=0): # Add try except since dataloader_iter in PTL 2.0 doesnt catch the end of iterables try: @@ -379,18 +358,22 @@ def eval_step(self, dataloader_iter, batch_idx, dataloader_idx=0): outputs=tokens_enc, tokenizer=self.encoder_tokenizer, processor=source_processor, ) - val_outputs = { + loss_dict = { 'inputs': encoder_inputs, 'translations': preds, 'ground_truths': labels, - 'loss': reduced_loss, } + if isinstance(reduced_loss, dict): + loss_dict.update(reduced_loss) + else: + loss_dict['loss'] = reduced_loss + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(val_outputs) + self.validation_step_outputs[dataloader_idx].append(loss_dict) else: - self.validation_step_outputs.append(val_outputs) + self.validation_step_outputs.append(loss_dict) - return val_outputs + return loss_dict except StopIteration: return diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index 6a99e908f107..51ed1c7e7ef3 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -17,6 +17,7 @@ from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hiddens import MegatronHiddensModule from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults try: @@ -44,6 +45,7 @@ def __init__( encoder_attn_mask_type: AttnMaskType = None, decoder_attn_mask_type: AttnMaskType = None, hidden_steps: int = None, + hiddens_module: MegatronHiddensModule = None, # allows for hidden state transformations before the decoder ): super(MegatronTransformerEncoderDecoderModule, self).__init__() @@ -55,6 +57,12 @@ def __init__( f"hidden_steps cannot be None for perceiver encoders. It is needed to compute the encoder-decoder cross attention mask." ) + self.hiddens_module = hiddens_module + if self.hiddens_module is not None and not isinstance(self.hiddens_module, MegatronHiddensModule): + raise TypeError( + f"hiddens_module must be of type MegatronHiddensModule, but got {type(self.hiddens_module)} instead." + ) + # try to infer mask_type if not given if encoder_attn_mask_type is None: if encoder is None: @@ -83,6 +91,20 @@ def __init__( self._encoder_key = "encoder" self._decoder_key = "decoder" + self._hiddens_module = "hiddens_module" + + def get_hiddens_mask(self, enc_attn_mask): + """ + Returns the attention mask for the output of the encoder. + Required for fixed-size bottleneck models. 
+ """ + if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule): + # Attention mask is expected to be of shape [B x S] + hiddens_mask = torch.ones(enc_attn_mask.size(0), self.hidden_steps).to(enc_attn_mask.device) + else: + hiddens_mask = enc_attn_mask + + return hiddens_mask def encode( self, @@ -91,10 +113,11 @@ def encode( enc_layer_past=None, enc_get_key_value=False, enc_self_attention_relative_position_bias=None, + batch_data=None, ): + """Encodes embedder input using encoder""" if self.encoder is None: raise ValueError(f"Cannot call .encode(...) when self.encoder is None.") - """Encodes embedder input using encoder""" enc_output = self.encoder( enc_input=enc_input, enc_attn_mask=enc_attn_mask, @@ -103,6 +126,12 @@ def encode( enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias, ) + # apply hidden transformations if needed + if self.hiddens_module is not None: + enc_output = self.hiddens_module.apply_hidden_transforms( + {"hiddens": enc_output, "hiddens_mask": self.get_hiddens_mask(enc_attn_mask),}, batch_data=batch_data, + ) + return enc_output def decode( @@ -148,6 +177,7 @@ def forward( enc_self_attention_relative_position_bias=None, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + batch_data=None, ): # encoder if enc_output is None: @@ -158,6 +188,7 @@ def forward( enc_layer_past=enc_layer_past, enc_get_key_value=enc_get_key_value, enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias, + batch_data=batch_data, ) else: assert self.encoder_hidden_state is not None @@ -169,22 +200,21 @@ def forward( return enc_output # decoder - # Adjust encoder attention mask if encoder is a perceiver. - if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule): - # Attention mask is expected to be of shape [B x S] and enc_output is of size [S x B x H]. - enc_attn_mask = torch.ones(enc_output.size(1), self.hidden_steps).to(enc_output.device) - dec_output = self.decode( dec_input=dec_input, dec_attn_mask=dec_attn_mask, - enc_output=enc_output, - enc_attn_mask=enc_attn_mask, + enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations + if self.hiddens_module is not None + else enc_output, + # Adjust encoder attention mask if encoder is a perceiver. 
+ enc_attn_mask=self.get_hiddens_mask(enc_attn_mask), dec_layer_past=dec_layer_past, dec_get_key_value=dec_get_key_value, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, ) + # if self.hiddens_module is not None enc_output is a dict, else it is a torch.tensor return dec_output, enc_output def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): @@ -195,6 +225,9 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + if self.hiddens_module is not None: + state_dict_[self._hiddens_module] = self.hiddens_module.state_dict(destination, prefix, keep_vars) + return state_dict_ def load_state_dict(self, state_dict, strict=True): @@ -202,3 +235,5 @@ def load_state_dict(self, state_dict, strict=True): self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) + if self.hiddens_module is not None: + self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index fc16295020fb..928b3f6e8d83 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -28,6 +28,7 @@ KERPLERelativePositionEmbedding, T5RelativePositionEmbedding, ) +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hiddens import get_hiddens_module from nemo.collections.nlp.modules.common.megatron.utils import ( ApexGuardDefaults, build_position_ids, @@ -124,6 +125,7 @@ def __init__( share_token_embeddings=True, share_decoder_tokens_head_embeddings=True, tokens_head_bias=True, + hiddens_cfg: DictConfig = None, # allows for hidden state transformations before the decoder ): super(MegatronTokenLevelEncoderDecoderModule, self).__init__() @@ -140,6 +142,7 @@ def __init__( self.share_token_embeddings = share_token_embeddings self.share_decoder_tokens_head_embeddings = share_decoder_tokens_head_embeddings self.tokens_head_bias = tokens_head_bias + self.hiddens_cfg = hiddens_cfg encoder_kv_channels, decoder_kv_channels = self._validate_config() @@ -388,8 +391,12 @@ def __init__( use_flash_attention=decoder_cfg.get('use_flash_attention', False), ) + hiddens_module = get_hiddens_module(hiddens_cfg) self.enc_dec_model = MegatronTransformerEncoderDecoderModule( - encoder=encoder, decoder=decoder, hidden_steps=encoder_cfg.get('hidden_steps', -1), + encoder=encoder, + decoder=decoder, + hidden_steps=encoder_cfg.get('hidden_steps', -1), + hiddens_module=hiddens_module, ) self._enc_dec_model_key = "enc_dec_model" @@ -455,6 +462,10 @@ def _validate_config(self): assert ( self.share_decoder_tokens_head_embeddings ), "Decoder token embeddings and the outputlayer must be shared when using pipeline model parallel size > 1" + assert ( + self.hiddens_cfg is None + ), "Hiddens module must not be enabled when using pipeline model parallel size > 1" + return encoder_kv_channels, decoder_kv_channels def set_input_tensor(self, input_tensor): @@ -493,6 
+504,7 @@ def forward( dec_attn_mask=None, token_type_ids=None, labels=None, + batch_data=None, # additional data to be passed to hiddens module enc_output=None, # Result of running the entire encoder enc_output_attn_mask=None, enc_input=None, # Result of running encoder embedding only @@ -554,9 +566,11 @@ def forward( enc_layer_past=None, enc_get_key_value=False, enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + batch_data=batch_data, ) else: enc_output = self.enc_dec_model.encoder_hidden_state + return enc_output else: if enc_output_attn_mask is None: @@ -598,10 +612,11 @@ def forward( enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias, + batch_data=batch_data, ) if self.post_process and self.add_decoder: - dec_output, enc_output = output # [s, b, h] + dec_output, enc_output = output # [s, b, h], enc_output might be a dict if hiddens_module is used # project decoder output to vocabulary-size dimensions if self.share_decoder_tokens_head_embeddings: token_logits = self.tokens_head(dec_output, self.word_embeddings_weight()) @@ -609,6 +624,7 @@ def forward( token_logits = self.tokens_head(dec_output)[0] if labels is not None: + # compute loss here # [b, s] -> [s, b] labels = labels.transpose(0, 1).contiguous() @@ -625,11 +641,30 @@ def forward( # [s, b] -> [b, s] tokens_loss = tokens_loss.transpose(0, 1).contiguous() - return tokens_loss + # check if hiddens is used + if self.hiddens_cfg is not None: + loss_dict = self.enc_dec_model.hiddens_module.apply_loss_transforms( + outputs=enc_output, batch_data=batch_data, + ) + loss_dict["tokens_loss"] = tokens_loss + # We need to store default output in a known key, so that we can mimic default behaviour + loss_dict["output"] = tokens_loss + return loss_dict + else: + return tokens_loss else: + # else return token logits (and hiddens if needed) # [s, b, h] -> [b, s, h] token_logits = token_logits.transpose(0, 1).contiguous() - return token_logits + if self.hiddens_cfg is not None: + # return all hiddens and token logits + hiddens_dict = enc_output + hiddens_dict["token_logits"] = token_logits + # We need to store default output in a known key, so that we can mimic default behaviour + hiddens_dict["output"] = token_logits + return hiddens_dict + else: + return token_logits elif self.add_decoder and not self.add_encoder: decoder_output, _ = output diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/__init__.py b/nemo/collections/nlp/modules/common/megatron/transformations/__init__.py new file mode 100644 index 000000000000..50a412ac2e13 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .megatron_hidden_loss import * +from .megatron_hidden_transform import * diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py new file mode 100644 index 000000000000..f10c34d3fad3 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import torch + +__all__ = ["MegatronBaseHiddenLoss", "MegatronAMIMHiddenLoss", "MegatronVAEHiddenLoss"] + + +class MegatronBaseHiddenLoss(torch.nn.Module): + """ + Base class to calculate hidden state loss. + Returned dict includes a loss value and additional outputs. + """ + + def __init__(self, loss_weight=1.0, name=""): + super().__init__() + self.name = name + self.loss_weight = float(loss_weight) + + def __str__(self): + return super().__str__() + f"(name={self.name})" + + def _validate_inputs(self, inputs): + """Validate inputs""" + # validate inputs + if not set(self.input_names).issubset(set(inputs.keys())): + raise ValueError(f"Inputs should contain {self.input_names}, but got {inputs.keys()}") + + @property + def input_names(self): + """Returns and caches input names""" + # we always expect hiddens_mask to be used to mask out loss of padded elements + return self._input_names() + ["hiddens_mask"] + + def _input_names(self): + """Add here all required inputs""" + return [] + + def _loss(self, inputs, batch_data=None): + """ + We expect input shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + We return a dictionary with dimensions [B x S x H], [B x S], [B], or []. + + Implement your own loss calculations. Must return "loss" key. + loss shape - [B x S] for Batch, Sequence sizes + batch_data - a dictionary of additional data that can be used to calculate loss + + Returns: + dict: a dictionary with loss and additional outputs (must include "loss" key) + example: {"loss": 0.0} + """ + raise NotImplementedError("Please implement loss calculations in child class") + + def loss(self, inputs, batch_data=None): + """A wrapper around custom _loss that adds a weighted loss and name to the output dict""" + self._validate_inputs(inputs) + + loss_dict = self._loss(inputs, batch_data=batch_data) + if "loss" not in loss_dict: + raise KeyError("Loss dict must contain 'loss' key") + + # average loss over active steps only. 
loss [B x S] + loss = loss_dict["loss"] + # hiddens_mask has shape of [B x S] + hiddens_mask = inputs["hiddens_mask"].to(loss) + loss = loss * hiddens_mask + # sequence level loss [B x S] -> batch level loss [B] + loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0) + + # compute batch level weighted loss (scalar) + weighted_loss = loss.sum() * self.loss_weight + + # store updated losses + loss_dict["loss"] = loss + loss_dict["weighted_loss"] = weighted_loss + loss_dict["weight_loss"] = torch.tensor(self.loss_weight).to(weighted_loss) + + return loss_dict + + +class MegatronAMIMHiddenLoss(MegatronBaseHiddenLoss): + """ + Based on + Implements A-MIM loss with a unit Normal anchor. + A-MIM - asymmetric MIM (without sampling) + """ + + def __init__(self, loss_weight=1.0, hidden_aggregation_method="sum", name="mim"): + super().__init__( + name=name, loss_weight=loss_weight, + ) + + # allows to determine how to aggregate hidden loss over hidden dimension + self.hidden_aggregation_method = hidden_aggregation_method + + def _input_names(self): + """Add here all required inputs""" + return ["z", "z_log_prob"] + + def _loss(self, inputs, batch_data=None): + """ + We expect input shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + We return a dictionary with dimensions [B x S x H], [B x S], [B], or []. + + Implement your own loss calculations. Must return "loss" key. + loss shape - [B x S] for Batch, Sequence sizes + batch_data - a dictionary of additional data that can be used to calculate loss + """ + z = inputs["z"] + # get posterior + log_prob_q_z_given_x = inputs["z_log_prob"] + # compute log prob of anchor a unit Normal distribution + log_prob_P_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)) + # aggregate over hidden dimension, default is sum + log_prob_P_z = getattr(log_prob_P_z, self.hidden_aggregation_method)(dim=-1) + + # A-MIM loss = log_p_x_given_z - 0.5 * (log_prob_P_z + log_prob_q_z_given_x) + # here we return only the hidden loss part + loss = -0.5 * (log_prob_P_z + log_prob_q_z_given_x) + + # return losses shaped [B x S] + return { + "loss": loss.transpose(0, 1), + "log_prob_P_z": log_prob_P_z.transpose(0, 1), + "log_prob_q_z_given_x": log_prob_q_z_given_x.transpose(0, 1), + } + + +class MegatronVAEHiddenLoss(MegatronBaseHiddenLoss): + """ + Based on + Implements VAE loss with a unit Normal anchor. + """ + + def __init__(self, loss_weight=1.0, min_kl_value=None, name="vae"): + super().__init__( + name=name, loss_weight=loss_weight, + ) + + # minimum value for KL divergence + if min_kl_value is None: + self.min_kl_value = min_kl_value + else: + self.min_kl_value = float(min_kl_value) + + def _input_names(self): + """Add here all required inputs""" + return ["z", "z_log_prob"] + + def _loss(self, inputs, batch_data=None): + """ + We expect input shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + We return a dictionary with dimensions [B x S x H], [B x S], [B], or []. + + Implement your own loss calculations. Must return "loss" key. 
+ loss shape - [B x S] for Batch, Sequence sizes + batch_data - a dictionary of additional data that can be used to calculate loss + """ + z = inputs["z"] + # get posterior + log_prob_q_z_given_x = inputs["z_log_prob"] + # compute log prob of anchor a unit Normal distribution + log_prob_p_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)).sum(dim=-1) + + # VAE loss = log_p_x_given_z - KL(q(z|x) || p(z)) + kl_div = log_prob_q_z_given_x - log_prob_p_z + # here we return only the hidden loss part + loss = -kl_div + + # return losses shaped [B x S] + return { + "loss": loss.transpose(0, 1), + "kl_div": kl_div.transpose(0, 1), + "log_prob_p_z": log_prob_p_z.transpose(0, 1), + "log_prob_q_z_given_x": log_prob_q_z_given_x.transpose(0, 1), + } diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py new file mode 100644 index 000000000000..1c424a6a069b --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math + +import torch + +from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal + +try: + from megatron.core import tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +if not HAVE_MEGATRON_CORE: + raise NotImplementedError("Megatron Core is required to use Megatron Hidden Transformations") + +__all__ = ["MegatronBaseHiddenTransform", "MegatronGaussianHiddenTransform"] + + +class MegatronBaseHiddenTransform(torch.nn.Module): + """Base class to apply hidden state transformations""" + + def __init__(self, name=""): + super().__init__() + + self.name = name + + def __str__(self): + return super().__str__() + f"(name={self.name})" + + @property + def input_names(self): + """ + Provide here all required inputs + """ + return [] + + @property + def output_names(self): + """ + Provide here all generated outputs + """ + return [] + + def _validate_inputs(self, inputs): + """Validate inputs""" + # validate inputs + if not set(self.input_names).issubset(set(inputs.keys())): + raise ValueError(f"Inputs should contain {self.input_names}, but got {inputs.keys()}") + + def _transform(self, inputs, batch_data=None): + """ + Implement your own transformations. + We expect here shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + """ + # by default we pass inputs. 
+ outputs = inputs.copy() + + return outputs + + def transform(self, inputs, batch_data=None): + """Apply a transformations on the inputs (hiddens is always assumed)""" + # validate inputs + self._validate_inputs(inputs) + + outputs = self._transform(inputs, batch_data=batch_data) + + return outputs + + +class MegatronGaussianHiddenTransform(MegatronBaseHiddenTransform): + """ + Constructes a diagonal Gaussian distribution from the hidden states and samples from it using reparametrization. + """ + + def __init__(self, hidden_size, min_logvar=-6, init_method_std=0.02, name="cond_gaussian"): + super().__init__(name=name) + # limit smaller allowed variance (for numerical stability) + self.min_logvar = min_logvar + self.hidden_size = hidden_size + # project hiddens to mean and log variance (support tensor parallelism) + self.hiddens_to_mean_logvar = tensor_parallel.ColumnParallelLinear( + hidden_size, + hidden_size * 2, + gather_output=True, + init_method=init_method_normal(init_method_std), + skip_bias_add=False, + use_cpu_initialization=False, + bias=True, + sequence_parallel_enabled=False, + async_tensor_model_parallel_allreduce=True, + gradient_accumulation_fusion=False, + ) + + @property + def input_names(self): + """ + Provide here all required inputs + """ + return ["hiddens", "hiddens_mask"] + + @property + def output_names(self): + """ + Provide here all generated outputs + """ + return ["z_mean", "z_logvar", "z", "z_log_prob"] + + def _transform(self, inputs, batch_data=None): + """ + We expect here shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + + inputs: + hiddens: accepts a tensor of shape [S x B x H] + + outputs: + z: a sample from Gaussian a tensor of shape [S x B x H] + z_mean: mean of Gaussian a tensor of shape [S x B x H] + z_logvar: log variance of Gaussian a tensor of shape [S x B x H] + z_log_prob: log probability of z over posterior log q(z|x) a tensor of shape [S x B x H] + """ + hiddens = inputs["hiddens"] + # compute distribution's parameters (or use cached ones) + if "z_mean" in inputs and "z_logvar" in inputs: + z_mean = inputs["z_mean"] + z_logvar = inputs["z_logvar"] + else: + # ColumnLinear returns output and bias, we ignore bias here (already added to hiddens) + z_mean, z_logvar = self.hiddens_to_mean_logvar(hiddens)[0].chunk(2, dim=-1) + # clamp logvar + z_logvar = z_logvar.clamp(min=self.min_logvar) + # sample z with reparametrization (or use cached one) + if "z" in inputs: + z = inputs["z"] + z_log_prob = inputs.get("z_log_prob", None) + else: + e = torch.randn_like(hiddens) + z = (z_logvar * 0.5).exp() * e + z_mean + z_log_prob = None + + if z_log_prob is None: + # compute log probability of z under a diagonal Gaussian distribution + z_log_prob = -0.5 * (math.log(2 * math.pi) + z_logvar + (z - z_mean).pow(2) / z_logvar.exp()) + # sum over the last dimension (hidden_size) + z_log_prob = z_log_prob.sum(dim=-1) + + return { + "z": z, # [S x B x H] + "z_mean": z_mean, # [S x B x H] + "z_logvar": z_logvar, # [S x B x H] + "z_log_prob": z_log_prob, # [S x B] + } diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py new file mode 100644 index 000000000000..3e869a70f20f --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py @@ -0,0 +1,310 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +In order to register external hidden transforms and losses please use the following methods: +* register_hidden_loss(cls_name: str, class_path: str) +* register_hidden_transform(cls_name: str, class_path: str) + +See example config in: examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml +""" + +import functools +import itertools +from typing import List + +import torch +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss import MegatronBaseHiddenLoss +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform import ( + MegatronBaseHiddenTransform, +) +from nemo.utils import logging +from nemo.utils.model_utils import import_class_by_path + +__all__ = ["MegatronHiddensModule"] + +# a registry of all hidden transforms (maps name to class path) +_LOSS_CLASS_REGISTRY = { + "a_mim": "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss.MegatronAMIMHiddenLoss", + "vae": "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss.MegatronVAEHiddenLoss", +} + +# a registry of all hidden losses (maps name to class path) +_TRANSFORM_CLASS_REGISTRY = { + "cond_gaussian": "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform.MegatronGaussianHiddenTransform", +} + + +def get_registered_hiddens(): + """ + Return: + A dictionary with all registered hidden transforms and losses. + + Example: + { + "loss": ["a-mim", "vae"], + "transform": ["cond_gaussian"], + } + """ + return { + "loss": list(_LOSS_CLASS_REGISTRY.keys()), + "transform": list(_TRANSFORM_CLASS_REGISTRY.keys()), + } + + +def register_hidden_loss(cls_name: str, class_path: str): + """ + Register a hidden loss. + + + Args: + cls_name: name of the class + class_path: path to the class (e.g., "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform.MegatronGaussianHiddenTransform") + """ + if cls_name in _LOSS_CLASS_REGISTRY: + raise ValueError(f"Cannot register duplicate hidden loss ({cls_name})") + _LOSS_CLASS_REGISTRY[cls_name] = class_path + logging.info(f"Registered hidden loss {cls_name} at {class_path}") + + +def register_hidden_transform(cls_name: str, class_path: str): + """ + Register a hidden transform. + + Args: + cls_name: name of the class + class_path: path to the class (e.g., "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform.MegatronGaussianHiddenTransform") + """ + if cls_name in _TRANSFORM_CLASS_REGISTRY: + raise ValueError(f"Cannot register duplicate hidden transform ({cls_name})") + _TRANSFORM_CLASS_REGISTRY[cls_name] = class_path + logging.info(f"Registered hidden transform {cls_name} at {class_path}") + + +def get_hiddens_module(cfg=None): + """Build a MegatronHiddensModule from a configuration cfg""" + # Build a hiddens module if config is provided. 
+ if cfg is None: + return None + + logging.info(f"NOTE: Adding hiddens transforms and losses") + + # build all hidden transforms. We support a list or a dictionary of transforms (list enforces order) + transform_cfg = cfg.get("transform", []) + if isinstance(transform_cfg, (DictConfig, dict)): + transform_cfg = [transform_cfg] + hidden_transforms = [] + # here we expect transform_cfg to be a list of dictionaries + for cur_list_cfg in transform_cfg: + for name, cur_cfg in cur_list_cfg.items(): + cls_kwargs = OmegaConf.to_container(cur_cfg) + if not "cls_name" in cls_kwargs: + raise KeyError(f"Missing 'cls_name' in hidden transform {name}") + + cls_name = cls_kwargs.pop("cls_name") + # add name based on dictionary if not given in conf + if "name" not in cls_kwargs: + cls_kwargs["name"] = name + if cls_name not in _TRANSFORM_CLASS_REGISTRY: + raise KeyError(f"Unknown hidden transform {cls_name}, available: {_TRANSFORM_CLASS_REGISTRY.keys()}") + try: + cur_transform = import_class_by_path(_TRANSFORM_CLASS_REGISTRY[cls_name])(**cls_kwargs) + except Exception as e: + logging.error(f"Failed to build hidden transform {name} with cfg={cur_cfg}") + raise e + + hidden_transforms.append(cur_transform) + logging.info(f"Added transform {name} with cfg={cur_cfg}") + + # build all hidden losses + loss_cfg = cfg.get("loss", []) + if isinstance(loss_cfg, (DictConfig, dict)): + loss_cfg = [loss_cfg] + hidden_loss_transforms = [] + # here we expect loss_cfg to be a list of dictionaries + for cur_list_cfg in loss_cfg: + for name, cur_cfg in cur_list_cfg.items(): + cls_kwargs = OmegaConf.to_container(cur_cfg) + if not "cls_name" in cls_kwargs: + raise KeyError(f"Missing 'cls_name' in hidden loss {name}") + + cls_name = cls_kwargs.pop("cls_name") + # add name based on dictionary if not given in conf + if "name" not in cls_kwargs: + cls_kwargs["name"] = name + if cls_name not in _LOSS_CLASS_REGISTRY: + raise KeyError(f"Unknown hidden loss {cls_name}, available: {_LOSS_CLASS_REGISTRY.keys()}") + try: + cur_loss = import_class_by_path(_LOSS_CLASS_REGISTRY[cls_name])(**cls_kwargs) + except Exception as e: + logging.error(f"Failed to build hidden loss {name} with cfg={cur_cfg}") + raise e + hidden_loss_transforms.append(cur_loss) + logging.info(f"Added loss {name} with cfg={cur_cfg}") + + enc_output_name = cfg.get("enc_output_name", "hiddens") + + return MegatronHiddensModule( + hidden_transforms=hidden_transforms, + hidden_loss_transforms=hidden_loss_transforms, + enc_output_name=enc_output_name, + ) + + +class MegatronHiddensModule(torch.nn.Module): + """ + This class jointly handles the hidden transforms and hidden loss transforms. + It helps in validating, and applying the transforms. 
+    """
+
+    def __init__(
+        self,
+        hidden_transforms: List[MegatronBaseHiddenTransform] = [],
+        hidden_loss_transforms: List[MegatronBaseHiddenLoss] = [],
+        enc_output_name: str = "hiddens",  # name (key) of the encoder output
+        tokens_loss_weight: float = 1.0,  # weight of the tokens loss
+        loss_prefix: str = "hiddens_",  # if not None or "", add this prefix to all loss names
+    ):
+        super().__init__()
+        self.hidden_transforms = hidden_transforms
+        self.hidden_loss_transforms = hidden_loss_transforms
+        self.enc_output_name = enc_output_name
+        self.tokens_loss_weight = tokens_loss_weight
+        self.loss_prefix = loss_prefix
+
+        # register all hidden / loss transforms as submodules to support learned parameters
+        if not all([isinstance(ht, MegatronBaseHiddenLoss) for ht in self.hidden_loss_transforms]):
+            raise TypeError(
+                f"hidden_loss_transforms should be a list of MegatronBaseHiddenLoss, but got {hidden_loss_transforms}"
+            )
+        self.hidden_loss_transforms = torch.nn.ModuleList(self.hidden_loss_transforms)
+        if not all([isinstance(ht, MegatronBaseHiddenTransform) for ht in self.hidden_transforms]):
+            raise TypeError(
+                f"hidden_transforms should be a list of MegatronBaseHiddenTransform, but got {hidden_transforms}"
+            )
+        self.hidden_transforms = torch.nn.ModuleList(self.hidden_transforms)
+
+        # validate the inputs and outputs of all hidden transforms (make sure there are no duplicate output names)
+        duplicate_names = {}
+        # initialize with available outputs from hidden transforms with hiddens and mask as default
+        hidden_outputs = set(["hiddens", "hiddens_mask", "enc_output"])
+        for ht in self.hidden_transforms:
+            # validate that all required inputs are available by order of hidden transforms
+            cur_input_names = set(ht.input_names)
+            if not cur_input_names.issubset(hidden_outputs):
+                raise ValueError(
+                    f"Hidden transform {ht.name} requires inputs {cur_input_names - hidden_outputs} that are not available"
+                )
+
+            # collect all duplicate output names
+            cur_hidden_outputs = set(ht.output_names)
+            if not cur_hidden_outputs.isdisjoint(hidden_outputs):
+                duplicate_names[ht.name] = list(cur_hidden_outputs.intersection(hidden_outputs))
+
+            hidden_outputs.update(cur_hidden_outputs)
+
+        # fail here reporting all duplicate output names
+        if duplicate_names:
+            raise ValueError(
+                f"Hidden transforms have duplicate outputs {{name: [duplicate outputs]}} = {duplicate_names}"
+            )
+
+        # validate that all loss transforms are supported by output of hidden transforms ("hiddens" is given by default)
+        loss_inputs = set(itertools.chain(*[lt.input_names for lt in self.hidden_loss_transforms]))
+        if not loss_inputs.issubset(hidden_outputs):
+            loss_inputs_dict = {lt.name: lt.input_names for lt in self.hidden_loss_transforms}
+            raise ValueError(
+                f"Loss transforms inputs = {loss_inputs - hidden_outputs} are not supported by hidden transforms with hidden_outputs = {hidden_outputs}, expected inputs per loss = {loss_inputs_dict}"
+            )
+
+    @functools.cached_property
+    def hidden_outputs(self):
+        """Get the hidden outputs from all the hidden transforms"""
+        all_output_names = [ht.output_names for ht in self.hidden_transforms] + [["hiddens", "hiddens_mask"]]
+        output_names = set().union(*all_output_names)
+
+        return list(output_names)
+
+    @functools.cached_property
+    def loss_inputs(self):
+        """Get the loss inputs from all the loss transforms"""
+        loss_inputs = set().union(*[lt.input_names for lt in self.hidden_loss_transforms])
+        return list(loss_inputs)
+
+    def apply_hidden_transforms(self, inputs, batch_data=None):
+        """
+        Apply hidden transforms
+        Args:
+            inputs: a dictionary of inputs, with "hiddens" as the default key for hidden states
+            batch_data: a dictionary of batch data (e.g. "input_features"), optional
+
+        Returns:
+            outputs: a dictionary of the inputs together with all outputs collected from the hidden transforms
+        """
+        outputs = inputs.copy()
+        for hidden_transform in self.hidden_transforms:
+            # make sure to collect all outputs from hidden transforms
+            outputs.update(hidden_transform.transform(outputs, batch_data=batch_data))
+
+        # update final encoder output
+        outputs["enc_output"] = outputs[self.enc_output_name]
+
+        return outputs
+
+    def apply_loss_transforms(self, outputs, batch_data=None):
+        """
+        Apply loss transforms
+        Args:
+            outputs: a dictionary of outputs (after hidden transforms)
+            batch_data: a dictionary of batch data (e.g. "target_ids"), optional
+
+        Returns:
+            loss_dict: a dictionary of all losses,
+                {
+                    loss: joint loss (float),
+                    <name>_*: loss values from the loss transforms (the loss itself or its elements)
+                }
+        """
+        loss_dict = {}
+        joint_loss = 0.0
+        for i, loss_transform in enumerate(self.hidden_loss_transforms):
+            cur_loss_dict = loss_transform.loss(outputs, batch_data=batch_data)
+            joint_loss = joint_loss + cur_loss_dict["weighted_loss"]
+            cur_loss_dict.pop("weighted_loss")
+            # add name to loss values
+            if loss_transform.name:
+                cur_loss_dict = {f"{loss_transform.name}_{k}": v for k, v in cur_loss_dict.items()}
+
+            # check if cur_loss keys are unique - we do not allow to override keys
+            dup_keys = set(cur_loss_dict.keys()).intersection(set(loss_dict.keys()))
+            if len(dup_keys):
+                raise ValueError(
+                    f"Loss transform ({i}) {loss_transform} is trying to override the following loss keys {list(dup_keys)}"
+                )
+            # update loss dict
+            loss_dict.update(cur_loss_dict)
+
+        # joint weighted loss (float)
+        loss_dict["loss"] = joint_loss
+
+        # add prefix to all loss keys (default to 'hiddens_')
+        if self.loss_prefix:
+            loss_dict = {f"{self.loss_prefix}{k}": v for k, v in loss_dict.items()}
+
+        # add tokens loss weight (to be used by caller, or be ignored)
+        loss_dict["tokens_loss_weight"] = torch.tensor(self.tokens_loss_weight).to(joint_loss)
+
+        return loss_dict
diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py
index 7c7a428fa43f..045509d5adf9 100644
--- a/nemo/collections/nlp/modules/common/megatron/utils.py
+++ b/nemo/collections/nlp/modules/common/megatron/utils.py
@@ -383,8 +383,12 @@ def get_iterator_k_split(batch: List[torch.Tensor], num_microbatches: int) -> It
         microbatches = [dict(elem) for elem in microbatches]
     else:
         assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!"
-        split_batch = [torch.tensor_split(item, num_microbatches, dim=0) for item in batch]
-        microbatches = [[elem[i] for elem in split_batch] for i in range(num_microbatches)]
+        split_batch = [
+            torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch
+        ]
+        microbatches = [
+            [elem[i] if elem is not None else elem for elem in split_batch] for i in range(num_microbatches)
+        ]
 
     return itertools.chain(microbatches)
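
For readers wiring this up, here is a minimal usage sketch of the registration API introduced in the new hiddens module above; it is commentary, not part of the patch. The class paths my_project.hiddens.MyHiddenTransform / MyHiddenLoss, the registry names "my_transform" / "my_loss", and the constructor kwargs are hypothetical placeholders, and the import assumes the new file is exposed as nemo.collections.nlp.modules.common.megatron.transformations.megatron_hiddens.

# Usage sketch (illustrative only). External classes must subclass
# MegatronBaseHiddenTransform / MegatronBaseHiddenLoss; the class paths below are hypothetical.
from omegaconf import OmegaConf

from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hiddens import (
    get_hiddens_module,
    get_registered_hiddens,
    register_hidden_loss,
    register_hidden_transform,
)

register_hidden_transform("my_transform", "my_project.hiddens.MyHiddenTransform")
register_hidden_loss("my_loss", "my_project.hiddens.MyHiddenLoss")
print(get_registered_hiddens())
# -> {'loss': ['a_mim', 'vae', 'my_loss'], 'transform': ['cond_gaussian', 'my_transform']}

# The config layout mirrors what get_hiddens_module() parses: "transform" / "loss" hold lists of
# {name: {"cls_name": <registry key>, **constructor kwargs}}; the extra kwargs here are made up.
cfg = OmegaConf.create(
    {
        "enc_output_name": "z",  # assumes the hypothetical transform emits an output named "z"
        "transform": [{"q_z_given_x": {"cls_name": "my_transform", "hidden_size": 512}}],
        "loss": [{"mim": {"cls_name": "my_loss", "loss_weight": 1.0}}],
    }
)
hiddens_module = get_hiddens_module(cfg)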
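
A second sketch shows the runtime data flow through MegatronHiddensModule, based on apply_hidden_transforms and apply_loss_transforms above. The tensor shapes, the tokens_loss placeholder, and the final weighting are assumptions for illustration, not the actual call sites inside the encoder-decoder model.

# Data-flow sketch (assumes hiddens_module was built as above and its transforms produce all
# inputs required by the configured losses; shapes and the final combination are illustrative).
import torch

enc_hiddens = torch.randn(128, 4, 512)           # assumed [seq_len, batch, hidden] layout
enc_mask = torch.ones(4, 128, dtype=torch.bool)  # assumed [batch, seq_len] padding mask

# 1) Hidden transforms consume "hiddens" / "hiddens_mask" and add their own outputs (e.g. "z");
#    the configured enc_output_name is mirrored under "enc_output" for the decoder.
hiddens_dict = hiddens_module.apply_hidden_transforms({"hiddens": enc_hiddens, "hiddens_mask": enc_mask})
enc_output = hiddens_dict["enc_output"]

# 2) Loss transforms return per-loss terms plus the joint "loss", all prefixed with "hiddens_",
#    and a "tokens_loss_weight" scalar the caller may use (or ignore) to weight the token loss.
loss_dict = hiddens_module.apply_loss_transforms(hiddens_dict)
tokens_loss = torch.tensor(2.0)                  # placeholder for the token-level LM loss
total_loss = loss_dict["tokens_loss_weight"] * tokens_loss + loss_dict["hiddens_loss"]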