@@ -376,6 +376,7 @@ def _process_example(self, example):
             'context_length': len(context_ids),
             'answer_ids': answer_ids,
             'metadata': metadata,
+            'token_count': len(input_ids),
         }

         return processed_example
@@ -426,6 +427,7 @@ def collate_fn(self, batch):
         answers = [item['answer_ids'] for item in batch]
         loss_mask = [self._build_loss_mask(item)[1:] for item in batch]
         metadata = [item['metadata'] for item in batch]
+        token_count = [item['token_count'] for item in batch]

         max_length = max(max([len(x) for x in input_ids]), max([len(x) for x in contexts]) + self.tokens_to_generate)
         # increase max length to nearest multiple of 4 or 8
@@ -457,6 +459,7 @@ def collate_fn(self, batch):
             'context_lengths': context_lengths,
             'answers': answers,
             'metadata': metadata,
+            'token_count': token_count,
         }

         return processed_batch
@@ -516,6 +519,8 @@ def collate_fn(self, batch):

         loss_mask = [self._build_loss_mask(item) for item in batch]

+        token_count = [item.shape[0] for item in input_ids]
+
         if self.pad_to_max_length:
             max_length = self.max_seq_length
         else:
@@ -556,6 +561,7 @@ def collate_fn(self, batch):
             'loss_mask': torch.LongTensor(loss_mask),
             'position_ids': torch.LongTensor(position_ids),
             'cu_seqlens': torch.IntTensor(cu_seqlens),  # cu_seqlens_q must be in dtype torch.int32
+            'token_count': token_count,
         }

         return processed_batch
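
Taken together, the hunks thread a per-example token_count (the unpadded sequence length) from _process_example through both collate_fn variants, so the returned batch dict carries the original lengths alongside the padded tensors. Below is a minimal, self-contained sketch of that bookkeeping; apart from 'input_ids' and 'token_count', every name (collate_with_token_count, pad_id, 'tokens', the demo batch) is an illustrative assumption, not the actual NeMo collate_fn.

# Sketch only: illustrates the token_count bookkeeping added in this PR.
from typing import Dict, List

import torch


def collate_with_token_count(batch: List[Dict], pad_id: int = 0) -> Dict:
    # Variable-length token id lists, one per example.
    input_ids = [item['input_ids'] for item in batch]

    # Record each example's unpadded length before padding,
    # mirroring 'token_count' in the diff above.
    token_count = [len(ids) for ids in input_ids]

    # Right-pad everything to the longest example in the batch.
    max_length = max(token_count)
    padded = [ids + [pad_id] * (max_length - len(ids)) for ids in input_ids]

    return {
        'tokens': torch.LongTensor(padded),
        'token_count': token_count,  # kept as a plain Python list, as in the diff
    }


# Usage: two examples of different lengths.
batch = [{'input_ids': [5, 6, 7]}, {'input_ids': [8, 9]}]
out = collate_with_token_count(batch)
print(out['tokens'].shape)   # torch.Size([2, 3])
print(out['token_count'])    # [3, 2]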