From fe2070efd852e056e6b4fa45ab0fc3749d4fa706 Mon Sep 17 00:00:00 2001
From: Valerie Sarge
Date: Mon, 8 Jan 2024 15:42:21 -0800
Subject: [PATCH] Add token count and sequence length logging for MegatronGPTSFTModel as a config option

Signed-off-by: Valerie Sarge
---
 .../data/language_modeling/megatron/gpt_sft_dataset.py   | 6 ++++++
 .../models/language_modeling/megatron_gpt_sft_model.py   | 9 +++++++++
 2 files changed, 15 insertions(+)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 63c4f3459682a..8b73e158e898a 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -376,6 +376,7 @@ def _process_example(self, example):
             'context_length': len(context_ids),
             'answer_ids': answer_ids,
             'metadata': metadata,
+            'token_count': len(input_ids),
         }
 
         return processed_example
@@ -426,6 +427,7 @@ def collate_fn(self, batch):
         answers = [item['answer_ids'] for item in batch]
         loss_mask = [self._build_loss_mask(item)[1:] for item in batch]
         metadata = [item['metadata'] for item in batch]
+        token_count = [item['token_count'] for item in batch]
 
         max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate)
         # increase max length to nearest multiple of 4 or 8
@@ -457,6 +459,7 @@ def collate_fn(self, batch):
             'context_lengths': context_lengths,
             'answers': answers,
             'metadata': metadata,
+            'token_count': token_count,
         }
 
         return processed_batch
@@ -516,6 +519,8 @@ def collate_fn(self, batch):
 
         loss_mask = [self._build_loss_mask(item) for item in batch]
 
+        token_count = [item.shape[0] for item in input_ids]
+
         if self.pad_to_max_length:
             max_length = self.max_seq_length
         else:
@@ -556,6 +561,7 @@ def collate_fn(self, batch):
             'loss_mask': torch.LongTensor(loss_mask),
             'position_ids': torch.LongTensor(position_ids),
             'cu_seqlens': torch.IntTensor(cu_seqlens),  # cu_seqlens_q must be in dtype torch.int32
+            'token_count': token_count,
         }
 
         return processed_batch

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 350a70ae5439b..c2e67918e00d3 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -335,11 +335,20 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode):
 
     def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
         batch = next(dataloader_iter)
+
+        log_token_counts = self.cfg.get('log_token_counts', False)
+        if log_token_counts:
+            token_count_avg = sum(batch['token_count']) / len(batch['token_count'])
+
         # Pass only torch.Tensor to prevent errors when process get_iterator_k_split()
         batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
         _, seq_length = batch['tokens'].shape
         data_iter = get_iterator_k_split(batch, get_num_microbatches())
 
+        if log_token_counts:
+            self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
+            self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)
+
         # handle asynchronous grad reduction
         no_sync_func = None
         grad_sync_func = None
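
Usage note (not part of the patch): the new logging is gated on self.cfg.get('log_token_counts', False), so it is off by default. A minimal sketch of enabling it, assuming the usual OmegaConf-based NeMo SFT training config where the model section becomes self.cfg (the config file name below is hypothetical):

    # Sketch only: turn on the per-batch token-count / padded-sequence-length logging
    # added by this patch. Assumes a standard NeMo SFT config loaded with OmegaConf;
    # an equivalent Hydra command-line override would be model.log_token_counts=True.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("megatron_gpt_sft_config.yaml")  # hypothetical config file
    cfg.model.log_token_counts = True  # read by self.cfg.get('log_token_counts', False)

When enabled, fwd_bwd_step logs the padded sequence length as seq_length_padded and the per-batch average token count as tokens_avg to the progress bar via self.log(...), as shown in the last hunk above.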