From f2451f82437c1d7c69e930086bc7c29cb3d7bd87 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Mon, 20 Mar 2023 16:59:13 +0000
Subject: [PATCH] Further changes to fix fine-tuning

Signed-off-by: SeanNaren
---
 .../glue_benchmark/glue_benchmark_dataset.py  | 22 +++++++++++++++++--
 .../language_modeling/megatron_glue_model.py  |  8 +++++++
 .../megatron_lm_encoder_decoder_model.py      | 14 +++++++-----
 3 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py b/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py
index 8a1ebed77d3a..2a14aa5afc58 100644
--- a/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py
+++ b/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py
@@ -382,6 +382,7 @@ def __init__(
         max_seq_length_decoder: int = 128,
         use_cache: bool = True,
         prefix_override: str = None,
+        pad_to_max_length: bool = True,
     ):
         """
         Processes GLUE datasets
@@ -392,10 +393,12 @@ def __init__(
             max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
             use_cache: whether to use data cache
             prefix_override: if you want to override default prompt for this task specify this via a string.
+            pad_to_max_length: If true, pad to the maximum length.
         """
         super().__init__(file_name, task_name, tokenizer, max_seq_length, use_cache, compute_features=False)
         self.max_seq_length = max_seq_length
         self.max_seq_length_decoder = max_seq_length_decoder
+        self.pad_to_max_length = pad_to_max_length
         self.processor = processors[self.task_name]()
         self.prefix_override = prefix_override
         self.features = self.convert_examples_to_features()
@@ -412,9 +415,16 @@ def collate_fn(self, batch):
         dec_input = [item['text_dec'] for item in batch]
         labels = [item['labels'] for item in batch]
 
-        max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
         max_enc_query_length = max([len(item) for item in enc_query]) if enc_query else 0
+        max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
         max_label_length = max([len(item) for item in labels]) if labels else 0
+        if self.pad_to_max_length:
+            assert max_enc_query_length <= self.max_seq_length
+            assert max_dec_input_length <= self.max_seq_length_decoder
+            assert max_label_length <= self.max_seq_length_decoder
+            max_enc_query_length = self.max_seq_length
+            max_dec_input_length = self.max_seq_length_decoder
+            max_label_length = self.max_seq_length_decoder
 
         loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels]
         enc_query = [item + [self.tokenizer.pad_id] * (max_enc_query_length - len(item)) for item in enc_query]
@@ -488,10 +498,18 @@ def __init__(
         use_cache: bool = True,
         prefix_override: str = None,
         lang_list: List[str] = None,
+        pad_to_max_length: bool = True,
     ):
         self.lang_list = set(lang_list)
         super().__init__(
-            file_name, task_name, tokenizer, max_seq_length, max_seq_length_decoder, use_cache, prefix_override
+            file_name,
+            task_name,
+            tokenizer,
+            max_seq_length,
+            max_seq_length_decoder,
+            use_cache,
+            prefix_override,
+            pad_to_max_length,
         )
         if len(lang_list) <= 0 or lang_list is None:
             raise ValueError(f"Found an empty or None lang_list for {self.task_name}")
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py
index 5cc0f7ea3a32..541f9b663e9f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py
@@ -82,3 +82,11 @@ def build_train_valid_test_datasets(self, stage):
         self._train_ds = self._build_dataset(self.cfg.data.train_ds, check_implict_grad_acc=False)
         logging.info(f'Length of train dataset: {len(self._train_ds)}')
         logging.info(f'Finished building GLUE/XNLI datasets.')
+
+    @property
+    def max_decoder_seq_length(self) -> int:
+        return self._train_ds.max_seq_length_decoder
+
+    @property
+    def max_encoder_seq_length(self) -> int:
+        return self._train_ds.max_seq_length
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
index 19a3ec8a3d56..98ed2b11f531 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -306,7 +306,7 @@ def training_step(self, dataloader_iter, batch_idx):
         # we zero grads here because we also call backward in the megatron fwd/bwd functions
         self._optimizer.zero_grad()
 
-        tensor_shape = [self.cfg.seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
+        tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
 
         fwd_bwd_function = get_forward_backward_func()
 
@@ -317,7 +317,7 @@ def training_step(self, dataloader_iter, batch_idx):
             num_microbatches=get_num_microbatches(),
             forward_only=False,
             tensor_shape=tensor_shape,
-            decoder_seq_length=self.decoder_seq_length,
+            decoder_seq_length=self.max_decoder_seq_length,
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
         )
@@ -383,9 +383,13 @@ def training_step(self, dataloader_iter, batch_idx):
         return loss_mean
 
     @property
-    def decoder_seq_length(self) -> int:
+    def max_decoder_seq_length(self) -> int:
         return self._cfg.data.seq_length_dec
 
+    @property
+    def max_encoder_seq_length(self) -> int:
+        return self.cfg.seq_length
+
     def backward(self, *args, **kwargs):
         """ LightningModule hook to do backward.
             We want this to do nothing since we run backward in the fwd/bwd functions from apex.
         """
@@ -612,7 +616,7 @@ def validation_step(self, dataloader_iter, batch_idx):
         """
         return_values - if given, returns a dictionary with given keys and corresponding values
         """
-        tensor_shape = [self.cfg.seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
+        tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size]
         fwd_bwd_func = get_forward_backward_func()
         losses_reduced_per_micro_batch = fwd_bwd_func(
             forward_step_func=self.get_forward_output_and_loss_func(),
@@ -621,7 +625,7 @@ def validation_step(self, dataloader_iter, batch_idx):
             forward_only=True,
             tensor_shape=tensor_shape,
             num_microbatches=get_num_microbatches(),
-            decoder_seq_length=self.decoder_seq_length,
+            decoder_seq_length=self.max_decoder_seq_length,
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
         )
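
A brief illustration of the padding behaviour this patch relies on (not part of the patch itself): the `tensor_shape` and `decoder_seq_length` handed to Megatron's forward/backward schedule fix the sequence dimensions up front, so with `pad_to_max_length=True` the collate step pads every batch to `max_seq_length` / `max_seq_length_decoder` instead of to the longest item in the batch. The sketch below shows the idea with plain Python lists; `pad_batch` and `pad_id` are hypothetical names used only for this example and are not NeMo APIs.

from typing import List


def pad_batch(seqs: List[List[int]], max_len: int, pad_id: int = 0) -> List[List[int]]:
    # Hypothetical helper, illustration only: pad every sequence to a fixed max_len.
    # Dynamic padding would make the batch width vary with its longest item;
    # padding to a fixed maximum keeps the shape constant across micro-batches,
    # which is what a fixed tensor_shape in the pipeline schedule expects.
    assert all(len(s) <= max_len for s in seqs), "a sequence exceeds the fixed maximum length"
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


if __name__ == "__main__":
    enc_query = [[5, 6, 7], [8, 9]]
    # Every row comes back with length 4, regardless of the longest item in the batch.
    print(pad_batch(enc_query, max_len=4))  # [[5, 6, 7, 0], [8, 9, 0, 0]]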