Skip to content

Commit

Permalink
Add token count and sequence length logging for MegatronGPTSFTModel as a config option (#8136)
Browse files Browse the repository at this point in the history

Signed-off-by: Valerie Sarge <[email protected]>
  • Loading branch information
vysarge authored Jan 14, 2024
1 parent c30536d commit 1fede57
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def _process_example(self, example):
'context_length': len(context_ids),
'answer_ids': answer_ids,
'metadata': metadata,
'token_count': len(input_ids),
}

return processed_example
Expand Down Expand Up @@ -426,6 +427,7 @@ def collate_fn(self, batch):
answers = [item['answer_ids'] for item in batch]
loss_mask = [self._build_loss_mask(item)[1:] for item in batch]
metadata = [item['metadata'] for item in batch]
token_count = [item['token_count'] for item in batch]

max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate)
# increase max length to nearest multiple of 4 or 8
Expand Down Expand Up @@ -457,6 +459,7 @@ def collate_fn(self, batch):
'context_lengths': context_lengths,
'answers': answers,
'metadata': metadata,
'token_count': token_count,
}

return processed_batch
Expand Down Expand Up @@ -516,6 +519,8 @@ def collate_fn(self, batch):

loss_mask = [self._build_loss_mask(item) for item in batch]

token_count = [item.shape[0] for item in input_ids]

if self.pad_to_max_length:
max_length = self.max_seq_length
else:
Expand Down Expand Up @@ -556,6 +561,7 @@ def collate_fn(self, batch):
'loss_mask': torch.LongTensor(loss_mask),
'position_ids': torch.LongTensor(position_ids),
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
'token_count': token_count,
}

return processed_batch
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,20 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode):

def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
batch = next(dataloader_iter)

log_token_counts = self.cfg.get('log_token_counts', False)
if log_token_counts:
token_count_avg = sum(batch['token_count']) / len(batch['token_count'])

# Pass only torch.Tensor values to prevent errors when the batch is processed by get_iterator_k_split()
batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
_, seq_length = batch['tokens'].shape
data_iter = get_iterator_k_split(batch, get_num_microbatches())

if log_token_counts:
self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)

# handle asynchronous grad reduction
no_sync_func = None
grad_sync_func = None
Expand Down

0 comments on commit 1fede57

Please sign in to comment.