Merged
105 changes: 68 additions & 37 deletions src/lighteval/models/nanotron_model.py
@@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# ruff: noqa: C901,E120
# ruff: noqa: C901
import os
import time
from typing import List, Optional, Tuple, Type, Union
@@ -54,6 +54,7 @@
LoglikelihoodDataset,
LoglikelihoodSingleTokenDataset,
)
from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.models.base_model import LightevalModel
from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
@@ -1138,7 +1139,7 @@ def greedy_until(
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
for request in requests:
request.stop_sequence = list(request.stop_sequence) + [self.tokenizer.eos_token]
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits)
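The switch from list to as_list above matters when stop_sequence arrives as a bare string: list explodes it into single characters, each of which would then act as its own stop string, while as_list keeps it as one stop sequence. A minimal sketch of the difference, with a local stand-in that is assumed to mirror lighteval.utils.as_list:

def as_list(item):
    # Stand-in (assumption): wrap anything that is not already a list into a one-element list.
    return item if isinstance(item, list) else [item]

stop_sequence = "###"
print(list(stop_sequence) + ["</s>"])     # ['#', '#', '#', '</s>'] -- three one-character stop strings
print(as_list(stop_sequence) + ["</s>"])  # ['###', '</s>'] -- the intended stop sequence plus EOS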
@@ -1161,13 +1162,20 @@
dataset.split_start = subset_start
dataset.split_end = min(subset_start + subset_length, total_length)

context_enc = dataset[0][1].tokenized_context
max_gen = max(item[1].generation_size for item in dataset)
max_input_length = min(len(context_enc) + max_gen, self.max_length)
if dataset[0][1].generation_size is None:
# No constraints on the generation size: max length allowed is the max model context
max_input_length = self.max_length
else:
# Longest context in the current split is the first item (since we sort reversed)
context_enc = dataset[0][1].tokenized_context
max_gen = max(item[1].generation_size for item in dataset)
max_input_length = min(len(context_enc) + max_gen, self.max_length)

batch_size = self._get_batch_size(
override_bs=override_bs, max_input_length=max_input_length, starting_batch_size=starting_batch_size
)
starting_batch_size = batch_size * 2 # for the next round
# For next iteration, since the batch will be smaller, we'll test a bigger batch size
starting_batch_size = batch_size * 2
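# Worked example of the sizing above (illustrative sketch, not part of the file), assuming
# max_length = 2048 and a split whose longest tokenized context is 1500 tokens:
#   generation_size = None -> max_input_length = 2048 (the full model context)
#   generation_size = 256  -> max_input_length = min(1500 + 256, 2048) = 1756
# Later splits only contain shorter requests, which is why the next batch-size search is
# seeded with twice the batch size that was just accepted.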

# For the DP replicas
distributed_sampler = GenDistributedSampler(
@@ -1188,7 +1196,7 @@
)

tq = tqdm(dataloader, desc=f"greedy in subset {s} Node {dist.get_rank(self.parallel_context.world_pg)}")
for j, all_batch in enumerate(tq):
for j, indexed_batch in enumerate(tq):
if j < 3:
log_rank(
f"Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MB. Peak reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MB",
@@ -1198,22 +1206,59 @@
rank=0,
)
iteration_start_time = time.time()
example_index, batch_data = zip(*all_batch)
context = [c.tokenized_context for c in batch_data]
# we take the longest asked generation in the batch
# Multiple request may have different max generation length
max_tokens = max(d.generation_size for d in batch_data) # d[1][1]
if max_tokens <= 0:
raise ValueError("Greedy generation requires a positive value for max generation but we got -1")

max_context = self.max_length - max_tokens
padding_length = min(len(context[0]), max_context)
batch_model = self.prepare_batch(
example_index, batch = zip(*indexed_batch)

# NOTE: we are assuming all items in a batch behave similarly (same
# stop_tokens and max_tokens generated), which is not necessarily
# the case! Because of that we should only use a batch size of 1

# Since items are sorted by inverse length, the first one always has
# the maximum allowed generation size for the batch, unless we want to force truncation
# need to pass them somewhere ! stop_tokens = batch[0].stop_sequence
max_new_tokens = batch[0].generation_size

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context = [c.context for c in batch] # or tokenized context?
smallest_context = min(len(c) for c in context)
biggest_context = max(len(c) for c in context)
if smallest_context > self.max_length:
hlog_warn(
f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)

if (
biggest_context > self.max_length
): # There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)

# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
tokenized = self.tokenizer(
context,
padding_length=padding_length,
max_context=max_context,
pad_on_left=True,
full_attention_masks=False,
truncation="longest_first", # we truncate to the model max length if needed
padding="longest", # we pad to the longest sequence
return_tensors="pt",
max_length=self.max_length - 1, # we always allow minimum one token of generation
add_special_tokens=self.add_special_tokens,
).to(self.device)

batch_model = Batch(
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
input_mask=tokenized["attention_mask"],
truncated=[
len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0
for c in context
],
padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
)
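Taken together, the generation-length handling above reduces to a simple rule: keep the prompt intact whenever it fits, and fall back to a single generated token once even the longest prompt exceeds the model context. A standalone sketch of that rule (hypothetical helper, not part of the diff):

def clamp_max_new_tokens(biggest_context: int, max_length: int, requested: int) -> int:
    # Mirror of the policy above: prefer an intact prompt over a long generation.
    if biggest_context > max_length:
        # At least one sample will be truncated, so only one token of generation remains.
        return 1
    # Otherwise honour the request, but never run past the model context.
    return min(max_length - biggest_context, requested)

assert clamp_max_new_tokens(biggest_context=900, max_length=1024, requested=256) == 124
assert clamp_max_new_tokens(biggest_context=2048, max_length=1024, requested=256) == 1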

# responses, logits and input_ids have all been gathered across GPUs already
@@ -1222,10 +1267,9 @@
outputs = decode_tokenized(
input_ids=batch_model.input_ids,
input_mask=batch_model.input_mask,
# TODO @thomasw21: From ModelWithLoss extract the model.
model=self.model,
parallel_context=self.parallel_context,
max_new_tokens=max_tokens,
max_new_tokens=max_new_tokens,
max_micro_batch_size=batch_size, # ok for PP=1 for PP>1 we'll need to split the batch
returns_logits=returns_logits,
generation_config=self.generation_config,
@@ -1241,19 +1285,6 @@
logits = torch.stack([o.logits for o in outputs])
logits, len_logits = self.pad_and_gather(logits)

# if returns_logits:
# # Used input_ids to get its max_length
# transition_scores, len_logits = self.pad_and_gather(transition_scores)
# else:
# transition_scores, len_logits = None, None

# responses, logits, input_ids, len_resps, len_logits, len_ids = self._model_generate(
# input_ids=batched_inputs,
# attention_mask=attention_masks,
# max_tokens=max_tokens,
# stop=stop_tokens,
# returns_logits=returns_logits,
# )
if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
generations = batch_generations.numpy(force=True)
input_ids = batch_input_ids.numpy(force=True)