
Commit 0745e9a

Now handles the case where no generation size is set in a generative task description (#76)
When no generation size is set, we want to use the largest generation size the model allows (model.max_length). This should also fix the bug in #73. Also removes a duplicated TaskConfig definition.
1 parent 437a35f commit 0745e9a
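In short: the effective generation budget now defaults to the model's context window whenever a task leaves `generation_size` unset. A minimal sketch of that fallback rule, with an illustrative helper name and parameters that are not part of lighteval's API:

    def effective_generation_size(task_generation_size, model_max_length):
        # When the task does not set a generation size, fall back to the
        # largest budget the model can handle (its maximum context length).
        if task_generation_size is None:
            return model_max_length
        return task_generation_size

    effective_generation_size(None, 2048)  # -> 2048
    effective_generation_size(256, 2048)   # -> 256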

File tree

src/lighteval/data.py
src/lighteval/logging/info_loggers.py
src/lighteval/models/base_model.py
src/lighteval/tasks/lighteval_task.py

4 files changed: +63 -52 lines changed


src/lighteval/data.py

Lines changed: 3 additions & 0 deletions
@@ -195,6 +195,9 @@ def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsR
         """
         toks = request.tokenized_context
         gen_length = request.generation_size
+        # The generative task has no limit except the model context
+        if gen_length is None:
+            gen_length = 0
         return -(len(toks) + gen_length)

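The sort key above orders requests by the total length they may occupy (tokenized context plus generation budget), longest first, and a missing generation size now simply contributes 0 to that key. A standalone sketch of the same ordering, using plain (tokens, generation_size) tuples instead of lighteval's request objects:

    requests = [
        (["tok"] * 10, 50),     # short context, explicit generation budget
        (["tok"] * 200, None),  # long context, no generation budget set
        (["tok"] * 30, 5),
    ]

    def sort_key(request):
        toks, gen_length = request
        # An unbounded generation counts as 0, so the request is ranked by its context alone
        if gen_length is None:
            gen_length = 0
        return -(len(toks) + gen_length)

    for toks, gen in sorted(requests, key=sort_key):
        print(len(toks), gen)
    # 200 None
    # 10 50
    # 30 5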
src/lighteval/logging/info_loggers.py

Lines changed: 3 additions & 45 deletions
@@ -13,7 +13,7 @@
 from lighteval.metrics.stderr import get_stderr_function
 from lighteval.models.model_loader import ModelInfo
 from lighteval.models.model_output import ModelReturn
-from lighteval.tasks.lighteval_task import LightevalTask
+from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
 from lighteval.tasks.requests import Doc
 from lighteval.utils import as_list, is_nanotron_available, sanitize_numpy

@@ -497,53 +497,11 @@ class TaskConfigLogger:
     """Logs the different parameters of the current [`LightevalTask`] of interest.

     Attributes:
-        tasks_config (dict[str, TaskConfig]): Maps each task to its associated [`TaskConfig`]
+        tasks_config (dict[str, LightevalTaskConfig]): Maps each task to its associated [`LightevalTaskConfig`]

     """

-    @dataclass
-    class TaskConfig:
-        """Stored configuration of a given [`LightevalTask`].
-
-        Arguments:
-            name (str): Short name of the evaluation task.
-            suite (list[str]): Evaluation suites to which the task belongs.
-            prompt_function (str): Name of the function used to create the [`Doc`] samples from each line of the evaluation dataset.
-            hf_repo (str): Path of the hub dataset repository containing the evaluation information.
-            hf_subset (str): Subset used for the current task, will be default if none is selected.
-            hf_avail_splits (list[str]): All the available splits in the evaluation dataset
-            evaluation_splits (list[str]): List of the splits actually used for this evaluation
-            few_shots_split (str): Name of the split from which to sample few-shot examples
-            few_shots_select (str): Method with which to sample few-shot examples
-            generation_size (int): Maximum allowed size of the generation
-            metric (list[str]): List of all the metrics for the current task.
-            stop_sequence (list[str]): Stop sequence which interrupts the generation for generative metrics.
-            original_num_docs (int): Number of documents in the task
-            effective_num_docs (int): Number of documents used in a specific evaluation
-            truncated_num_docs (bool): Whether less than the total number of documents were used
-            output_regex (str)
-            frozen (bool)
-
-        """
-
-        name: str
-        suite: list[str]
-        prompt_function: str
-        hf_repo: str
-        hf_subset: str
-        hf_avail_splits: list[str]
-        evaluation_splits: list[str]
-        few_shots_split: str
-        few_shots_select: str
-        generation_size: int
-        metric: list[str]
-        stop_sequence: list[str]
-        output_regex: str
-        frozen: bool
-        original_num_docs: int = -1
-        effective_num_docs: int = -1
-
-    tasks_configs: dict[str, TaskConfig] = {}
+    tasks_configs: dict[str, LightevalTaskConfig] = {}

     def log(self, task_dict: dict[str, LightevalTask]) -> None:
         self.tasks_configs = {name: task.cfg for name, task in task_dict.items()}
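The only behavioural change here is that the logger reuses `LightevalTaskConfig` from `lighteval.tasks.lighteval_task` instead of keeping its own `TaskConfig` copy; `log` itself is untouched. A rough sketch of the same pattern, one shared config dataclass consumed by a logger, with simplified stand-in classes rather than the library's real signatures:

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class SharedTaskConfig:
        # Stand-in for LightevalTaskConfig: a single definition shared across modules
        name: str
        generation_size: Optional[int] = None

    @dataclass
    class ConfigLogger:
        # Stand-in for TaskConfigLogger: maps each task name to the shared config
        tasks_configs: dict = field(default_factory=dict)

        def log(self, task_configs: dict) -> None:
            self.tasks_configs = dict(task_configs)

    logger = ConfigLogger()
    logger.log({"custom|mytask|0": SharedTaskConfig(name="mytask")})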

src/lighteval/models/base_model.py

Lines changed: 29 additions & 5 deletions
@@ -354,9 +354,17 @@ def greedy_until(
             position=0,
             disable=self.disable_tqdm,
         ):
-            # Longest context in the current split is the first item (since we sort reversed)
-            longest_context_continuation_size_in_split = len(dataset[0].tokenized_context) + dataset[0].generation_size
-            max_context_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length)
+            if dataset[0].generation_size is None:
+                # No constraints on the generation size: max length allowed is the max model context
+                max_context_continuation_size_allowed = self.max_length
+            else:
+                # Longest context in the current split is the first item (since we sort reversed)
+                longest_context_continuation_size_in_split = (
+                    len(dataset[0].tokenized_context) + dataset[0].generation_size
+                )
+                max_context_continuation_size_allowed = min(
+                    longest_context_continuation_size_in_split, self.max_length
+                )
             batch_size = self._get_batch_size(
                 override_bs=override_bs,
                 max_input_length=max_context_continuation_size_allowed,
@@ -376,9 +384,25 @@
                 # stop_tokens and max_tokens genrated) which is not necessarily
                 # the case! Because of that we only use batch size of 1
                 stop_tokens = batch[0].stop_sequence
-                max_generated_tokens = batch[0].generation_size
                 context = [c.context for c in batch]
-                max_context_size_allowed = self.max_length - max_generated_tokens
+                max_context_size_allowed = self.max_length
+                if batch[0].generation_size is None:
+                    # No constraints on max tokens except the model and data
+                    # Max generation possible is the max_length - the smallest context
+                    smallest_context = min([len(c) for c in context])
+                    if smallest_context < self.max_length:
+                        max_generated_tokens = self.max_length - smallest_context
+                        max_context_size_allowed = self.max_length
+                    else:
+                        # The max context size is smaller than the smallest context
+                        max_generated_tokens = 1
+                        max_context_size_allowed = self.max_length - 1
+                        hlog_warn(
+                            f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in {[i.task_name for i in batch]}. This is likely to lead to some errors."
+                        )
+                else:
+                    max_generated_tokens = batch[0].generation_size
+                    max_context_size_allowed = self.max_length - max_generated_tokens

                 tokenized = self.tokenizer(
                     context,
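Taken on its own, the new budget logic in the second hunk behaves as follows. This is a simplified restatement, with plain integers standing in for the lengths the code measures with len() and without the hlog_warn call; it is not a drop-in replacement for the method above:

    def generation_budget(generation_size, context_lengths, max_length):
        # Returns (max_generated_tokens, max_context_size_allowed).
        if generation_size is None:
            # No task-level limit: spend whatever the model window leaves
            # after the smallest context in the batch.
            smallest_context = min(context_lengths)
            if smallest_context < max_length:
                return max_length - smallest_context, max_length
            # The smallest context already fills the model window:
            # keep a single generated token (the real code also warns here).
            return 1, max_length - 1
        # Explicit task budget: reserve it out of the model window.
        return generation_size, max_length - generation_size

    print(generation_budget(None, [120, 300], 2048))  # (1928, 2048)
    print(generation_budget(256, [120, 300], 2048))   # (256, 1792)
    print(generation_budget(None, [4096], 2048))      # (1, 2047)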

src/lighteval/tasks/lighteval_task.py

Lines changed: 28 additions & 2 deletions
@@ -42,6 +42,29 @@

 @dataclass
 class LightevalTaskConfig:
+    """Stored configuration of a given [`LightevalTask`].
+
+    Arguments:
+        name (str): Short name of the evaluation task.
+        suite (list[str]): Evaluation suites to which the task belongs.
+        prompt_function (str): Name of the function used to create the [`Doc`] samples from each line of the evaluation dataset.
+        hf_repo (str): Path of the hub dataset repository containing the evaluation information.
+        hf_subset (str): Subset used for the current task, will be default if none is selected.
+        hf_avail_splits (list[str]): All the available splits in the evaluation dataset
+        evaluation_splits (list[str]): List of the splits actually used for this evaluation
+        few_shots_split (str): Name of the split from which to sample few-shot examples
+        few_shots_select (str): Method with which to sample few-shot examples
+        generation_size (int): Maximum allowed size of the generation
+        metric (list[str]): List of all the metrics for the current task.
+        stop_sequence (list[str]): Stop sequence which interrupts the generation for generative metrics.
+        original_num_docs (int): Number of documents in the task
+        effective_num_docs (int): Number of documents used in a specific evaluation
+        truncated_num_docs (bool): Whether less than the total number of documents were used
+        output_regex (str)
+        frozen (bool)
+
+    """
+
     name: str
     prompt_function: str
     hf_repo: str
@@ -51,12 +74,15 @@ class LightevalTaskConfig:
     evaluation_splits: Optional[Tuple[str]] = None
     few_shots_split: Optional[str] = None
     few_shots_select: Optional[str] = None
-    generation_size: int = -1
+    generation_size: int = None
     stop_sequence: Optional[Tuple[str]] = None
    output_regex: Optional[str] = None

     frozen: bool = False
-    suite: Optional[Tuple[str]] = None # we use this to know if we should use a custom lighteval or bigcode task
+    suite: Optional[Tuple[str]] = None
+
+    original_num_docs: int = -1
+    effective_num_docs: int = -1

     def as_dict(self):
         return {
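With `generation_size` now defaulting to `None`, a task definition can simply omit it and inherit the model's maximum length at generation time. A hypothetical example built from the fields documented above (names and values are placeholders, and required constructor fields beyond those shown in this diff may differ):

    from lighteval.tasks.lighteval_task import LightevalTaskConfig

    cfg = LightevalTaskConfig(
        name="my_generative_task",
        prompt_function="my_prompt_fn",        # placeholder prompt function
        hf_repo="my-org/my-eval-dataset",      # placeholder dataset repo
        hf_subset="default",
        metric=("exact_match",),               # placeholder metric
        evaluation_splits=("test",),
        stop_sequence=("\n",),
        suite=("custom",),
        # generation_size is left unset (None): greedy_until now falls back
        # to the model's max_length instead of requiring a fixed budget.
    )

    print(cfg.generation_size)  # None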
