diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py
index 89c1b75c9..69ad420f1 100644
--- a/src/lighteval/models/nanotron_model.py
+++ b/src/lighteval/models/nanotron_model.py
@@ -20,7 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-# ruff: noqa: C901,E120
+# ruff: noqa: C901
 import os
 import time
 from typing import List, Optional, Tuple, Type, Union
@@ -54,6 +54,7 @@
     LoglikelihoodDataset,
     LoglikelihoodSingleTokenDataset,
 )
+from lighteval.logging.hierarchical_logger import hlog_warn
 from lighteval.models.base_model import LightevalModel
 from lighteval.models.model_config import EnvConfig
 from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
@@ -1138,7 +1139,7 @@ def greedy_until(
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
         for request in requests:
-            request.stop_sequence = list(request.stop_sequence) + [self.tokenizer.eos_token]
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
             request.tokenized_context = self.tok_encode(request.context)
 
         dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits)
@@ -1161,13 +1162,20 @@ def greedy_until(
             dataset.split_start = subset_start
             dataset.split_end = min(subset_start + subset_length, total_length)
 
-            context_enc = dataset[0][1].tokenized_context
-            max_gen = max(item[1].generation_size for item in dataset)
-            max_input_length = min(len(context_enc) + max_gen, self.max_length)
+            if dataset[0][1].generation_size is None:
+                # No constraints on the generation size: max length allowed is the max model context
+                max_input_length = self.max_length
+            else:
+                # Longest context in the current split is the first item (since we sort reversed)
+                context_enc = dataset[0][1].tokenized_context
+                max_gen = max(item[1].generation_size for item in dataset)
+                max_input_length = min(len(context_enc) + max_gen, self.max_length)
+
             batch_size = self._get_batch_size(
                 override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
             )
-            starting_batch_size = batch_size * 2  # for the next round
+            # For next iteration, since the batch will be smaller, we'll test a bigger batch size
+            starting_batch_size = batch_size * 2
 
             # For the DP replicas
             distributed_sampler = GenDistributedSampler(
@@ -1188,7 +1196,7 @@ def greedy_until(
             )
 
             tq = tqdm(dataloader, desc=f"greedy in subset {s} Node {dist.get_rank(self.parallel_context.world_pg)}")
-            for j, all_batch in enumerate(tq):
+            for j, indexed_batch in enumerate(tq):
                 if j < 3:
                     log_rank(
                         f"Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MB. Peak reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MB",
@@ -1198,22 +1206,59 @@ def greedy_until(
                         rank=0,
                     )
                 iteration_start_time = time.time()
-                example_index, batch_data = zip(*all_batch)
-                context = [c.tokenized_context for c in batch_data]
-                # we take the longest asked generation in the batch
-                # Multiple request may have different max generation length
-                max_tokens = max(d.generation_size for d in batch_data)  # d[1][1]
-                if max_tokens <= 0:
-                    raise ValueError("Greedy generation requires a positive value for max generation but we got -1")
-
-                max_context = self.max_length - max_tokens
-                padding_length = min(len(context[0]), max_context)
-                batch_model = self.prepare_batch(
+                example_index, batch = zip(*indexed_batch)
+
+                # NOTE: we are assuming all items in a batch behave similarly (same
+                # stop_tokens and max_tokens generated) which is not necessarily
+                # the case! Because of that we should only use batch size of 1
+
+                # Since items are sorted by inverse length, the first one always has
+                # the maximum allowed generation size for the batch, unless we want to force truncation
+                # need to pass them somewhere! stop_tokens = batch[0].stop_sequence
+                max_new_tokens = batch[0].generation_size
+
+                # The main question for this step is the following:
+                # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+                # of losing some meaning, or have some generations that are exceedingly short?
+                # The choice we go for here is to avoid truncating the prompt if we can, since it
+                # should have been managed by the prompt creator/few shot manager if requested by the user.
+                context = [c.context for c in batch]  # or tokenized context?
+                smallest_context = min(len(c) for c in context)
+                biggest_context = max(len(c) for c in context)
+                if smallest_context > self.max_length:
+                    hlog_warn(
+                        f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+                        + str({i.task_name for i in batch})
+                        + ". This is likely to lead to some errors."  # noqa: C401
+                    )
+
+                if (
+                    biggest_context > self.max_length
+                ):  # There will be truncation of at least one sample, maximum generation size will be one
+                    max_new_tokens = 1
+                else:  # We can't allow generation of more than max_length
+                    max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)
+
+                # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
+                # Will do left truncation and padding, as defined when creating the tokenizer
+                tokenized = self.tokenizer(
                     context,
-                    padding_length=padding_length,
-                    max_context=max_context,
-                    pad_on_left=True,
-                    full_attention_masks=False,
+                    truncation="longest_first",  # we truncate to the model max length if needed
+                    padding="longest",  # we pad to the longest sequence
+                    return_tensors="pt",
+                    max_length=self.max_length - 1,  # we always allow minimum one token of generation
+                    add_special_tokens=self.add_special_tokens,
+                ).to(self.device)
+
+                batch_model = Batch(
+                    input_ids=tokenized["input_ids"],
+                    input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
+                    input_mask=tokenized["attention_mask"],
+                    truncated=[
+                        len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0
+                        for c in context
+                    ],
+                    padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
                 )
 
                 # responses, logits and input_ids have all been gathered accross GPUs already
@@ -1222,10 +1267,9 @@ def greedy_until(
                 outputs = decode_tokenized(
                     input_ids=batch_model.input_ids,
                     input_mask=batch_model.input_mask,
-                    # TODO @thomasw21: From ModelWithLoss extract the model.
                     model=self.model,
                     parallel_context=self.parallel_context,
-                    max_new_tokens=max_tokens,
+                    max_new_tokens=max_new_tokens,
                     max_micro_batch_size=batch_size,  # ok for PP=1 for PP>1 we'll need to split the batch
                     returns_logits=returns_logits,
                     generation_config=self.generation_config,
@@ -1241,19 +1285,6 @@ def greedy_until(
                     logits = torch.stack([o.logits for o in outputs])
                     logits, len_logits = self.pad_and_gather(logits)
 
-                # if returns_logits:
-                #     # Used input_ids to get its max_length
-                #     transition_scores, len_logits = self.pad_and_gather(transition_scores)
-                # else:
-                #     transition_scores, len_logits = None, None
-
-                # responses, logits, input_ids, len_resps, len_logits, len_ids = self._model_generate(
-                #     input_ids=batched_inputs,
-                #     attention_mask=attention_masks,
-                #     max_tokens=max_tokens,
-                #     stop=stop_tokens,
-                #     returns_logits=returns_logits,
-                # )
                 if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                     generations = batch_generations.numpy(force=True)
                     input_ids = batch_input_ids.numpy(force=True)
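
For context on the new tokenizer-based batching: below is a minimal, standalone sketch, not part of this diff, of how a tokenizer configured with left padding and left truncation produces the batch shape the added code expects, together with padded/truncated bookkeeping similar in spirit to the new Batch construction (done here on token counts rather than raw string lengths). The "gpt2" checkpoint, max_model_length, and the toy prompts are illustrative assumptions.

# Minimal sketch (not lighteval code): left-padded, left-truncated batching with
# a Hugging Face tokenizer, assuming the "gpt2" checkpoint and a toy context size.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
tokenizer.padding_side = "left"            # pad on the left so every prompt ends at the last position
tokenizer.truncation_side = "left"         # drop the oldest tokens when a prompt is too long

max_model_length = 32  # assumed model context size for this example

contexts = [
    "A short prompt.",
    "A much longer prompt " + "with many repeated words " * 10,
]

tokenized = tokenizer(
    contexts,
    truncation="longest_first",       # truncate to max_length if needed
    padding="longest",                # pad to the longest sequence in the batch
    return_tensors="pt",
    max_length=max_model_length - 1,  # keep at least one slot free for generation
)

# Per-sample bookkeeping: how many positions were padded and how many tokens
# would have been cut off relative to the untruncated encoding.
seq_len = tokenized["input_ids"].shape[1]
full_lengths = [len(tokenizer(c)["input_ids"]) for c in contexts]
padded = [int((mask == 0).sum()) for mask in tokenized["attention_mask"]]
truncated = [max(length - seq_len, 0) for length in full_lengths]
print(seq_len, padded, truncated)

Because both padding and truncation happen on the left, each prompt ends at the final position of input_ids and generation continues directly from it; capping max_length at one token below the model context mirrors the diff's guarantee of at least one generated token.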