@@ -530,21 +530,19 @@ def greedy_until(
530
530
starting_batch_size = STARTING_BATCH_SIZE
531
531
results = []
532
532
533
- for split_start , split_end in tqdm (
534
- dataset .splits_start_end_iterator (),
533
+ for split in tqdm (
534
+ dataset .splits_iterator (),
535
535
total = dataset .num_dataset_splits ,
536
536
desc = "Splits" ,
537
537
position = 0 ,
538
538
disable = self .disable_tqdm ,
539
539
):
540
- if dataset [0 ].generation_size is None :
540
+ if split [0 ].generation_size is None :
541
541
# No constraints on the generation size: max length allowed is the max model context
542
542
max_context_continuation_size_allowed = self .max_length
543
543
else :
544
544
# Longest context in the current split is the first item (since we sort reversed)
545
- longest_context_continuation_size_in_split = (
546
- len (dataset [0 ].tokenized_context ) + dataset [0 ].generation_size
547
- )
545
+ longest_context_continuation_size_in_split = len (split [0 ].tokenized_context ) + split [0 ].generation_size
548
546
max_context_continuation_size_allowed = min (
549
547
longest_context_continuation_size_in_split , self .max_length
550
548
)
@@ -556,7 +554,7 @@ def greedy_until(
556
554
# For next iteration, since the batch will be smaller, we'll test a bigger batch size
557
555
starting_batch_size = batch_size * 2
558
556
559
- dataloader = DataLoader (dataset , batch_size = batch_size , collate_fn = lambda batch : batch )
557
+ dataloader = DataLoader (split , batch_size = batch_size , collate_fn = lambda batch : batch )
560
558
if self .accelerator :
561
559
dataloader = self .accelerator .prepare (dataloader )
562
560
@@ -765,9 +763,9 @@ def _loglikelihood_tokens(
765
763
starting_batch_size = STARTING_BATCH_SIZE
766
764
res = []
767
765
768
- for split_start , split_end in tqdm (dataset .splits_start_end_iterator ()):
769
- context_enc = dataset [0 ].tokenized_context
770
- continuation_enc = dataset [0 ].tokenized_continuation
766
+ for split in tqdm (dataset .splits_iterator ()):
767
+ context_enc = split [0 ].tokenized_context
768
+ continuation_enc = split [0 ].tokenized_continuation
771
769
if rolling : # we take all the sequence in rolling mode
772
770
max_context_continuation_size_allowed = len (context_enc + continuation_enc )
773
771
else : # in normal mode, we left cut the context if needed
@@ -782,7 +780,7 @@ def _loglikelihood_tokens(
782
780
)
783
781
starting_batch_size = batch_size * 2
784
782
785
- dataloader = DataLoader (dataset , batch_size = batch_size , collate_fn = lambda batch : batch )
783
+ dataloader = DataLoader (split , batch_size = batch_size , collate_fn = lambda batch : batch )
786
784
if self .accelerator :
787
785
dataloader = self .accelerator .prepare (dataloader )
788
786
@@ -1009,13 +1007,13 @@ def _loglikelihood_single_token(
1009
1007
starting_batch_size = STARTING_BATCH_SIZE
1010
1008
res = []
1011
1009
1012
- for split_start , split_end in tqdm (dataset .splits_start_end_iterator ()):
1013
- context_enc = dataset [0 ].tokenized_context
1010
+ for split in tqdm (dataset .splits_iterator ()):
1011
+ context_enc = split [0 ].tokenized_context
1014
1012
max_context = len (context_enc [- self .max_length :])
1015
1013
batch_size = self ._get_batch_size (override_bs = self .config .batch_size , max_input_length = max_context )
1016
1014
starting_batch_size = batch_size * 2
1017
1015
1018
- dataloader = DataLoader (dataset , batch_size = starting_batch_size , collate_fn = lambda batch : batch )
1016
+ dataloader = DataLoader (split , batch_size = starting_batch_size , collate_fn = lambda batch : batch )
1019
1017
if self .accelerator is not None :
1020
1018
dataloader = self .accelerator .prepare (dataloader )
1021
1019
0 commit comments