Skip to content

Commit f177886

Browse files
vysarge and ssh-meister
authored and committed
Add token count and sequence length logging for MegatronGPTSFTModel as a config option (NVIDIA#8136)
Signed-off-by: Valerie Sarge <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
1 parent b49d539 commit f177886

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py

+6
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ def _process_example(self, example):
376376
'context_length': len(context_ids),
377377
'answer_ids': answer_ids,
378378
'metadata': metadata,
379+
'token_count': len(input_ids),
379380
}
380381

381382
return processed_example
@@ -426,6 +427,7 @@ def collate_fn(self, batch):
426427
answers = [item['answer_ids'] for item in batch]
427428
loss_mask = [self._build_loss_mask(item)[1:] for item in batch]
428429
metadata = [item['metadata'] for item in batch]
430+
token_count = [item['token_count'] for item in batch]
429431

430432
max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate)
431433
# increase max length to nearest multiple of 4 or 8
@@ -457,6 +459,7 @@ def collate_fn(self, batch):
457459
'context_lengths': context_lengths,
458460
'answers': answers,
459461
'metadata': metadata,
462+
'token_count': token_count,
460463
}
461464

462465
return processed_batch
@@ -516,6 +519,8 @@ def collate_fn(self, batch):
516519

517520
loss_mask = [self._build_loss_mask(item) for item in batch]
518521

522+
token_count = [item.shape[0] for item in input_ids]
523+
519524
if self.pad_to_max_length:
520525
max_length = self.max_seq_length
521526
else:
@@ -556,6 +561,7 @@ def collate_fn(self, batch):
556561
'loss_mask': torch.LongTensor(loss_mask),
557562
'position_ids': torch.LongTensor(position_ids),
558563
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
564+
'token_count': token_count,
559565
}
560566

561567
return processed_batch

nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py

+9
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,20 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode):
335335

336336
def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
337337
batch = next(dataloader_iter)
338+
339+
log_token_counts = self.cfg.get('log_token_counts', False)
340+
if log_token_counts:
341+
token_count_avg = sum(batch['token_count']) / len(batch['token_count'])
342+
338343
# Pass only torch.Tensor to prevent errors when process get_iterator_k_split()
339344
batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
340345
_, seq_length = batch['tokens'].shape
341346
data_iter = get_iterator_k_split(batch, get_num_microbatches())
342347

348+
if log_token_counts:
349+
self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
350+
self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)
351+
343352
# handle asynchronous grad reduction
344353
no_sync_func = None
345354
grad_sync_func = None

0 commit comments

Comments (0)