
Commit b42551f

hynky1999, NathanHB, and clefourrier authored
Adds Baseline workflow + fixes (#363)
* add baseline + fix tasks arg
* comments :)
* different model name so that the naming is consitent with normal models
---------
Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 1c730b8 commit b42551f

File tree: 3 files changed (+144, −8 lines)

src/lighteval/__main__.py

Lines changed: 17 additions & 8 deletions
@@ -22,19 +22,20 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
+
 import argparse
 import os
 from dataclasses import asdict
 from pprint import pformat
 
-from lighteval.parsers import parser_accelerate, parser_nanotron, parser_utils_tasks
+from lighteval.parsers import parser_accelerate, parser_baseline, parser_nanotron, parser_utils_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
 
 
 CACHE_DIR = os.getenv("HF_HOME")
 
 
-def cli_evaluate():
+def cli_evaluate():  # noqa: C901
     parser = argparse.ArgumentParser(description="CLI tool for lighteval, a lightweight framework for LLM evaluation")
     subparsers = parser.add_subparsers(help="help for subcommand", dest="subcommand")
 
@@ -46,9 +47,12 @@ def cli_evaluate():
     parser_b = subparsers.add_parser("nanotron", help="use nanotron as backend for evaluation.")
     parser_nanotron(parser_b)
 
+    parser_c = subparsers.add_parser("baseline", help="compute baseline for a task")
+    parser_baseline(parser_c)
+
     # Subparser for task utils functions
-    parser_c = subparsers.add_parser("tasks", help="display information about available tasks and samples.")
-    parser_utils_tasks(parser_c)
+    parser_d = subparsers.add_parser("tasks", help="display information about available tasks and samples.")
+    parser_utils_tasks(parser_d)
 
     args = parser.parse_args()
 
@@ -62,18 +66,24 @@
 
         main_nanotron(args.checkpoint_config_path, args.lighteval_config_path, args.cache_dir)
 
+    elif args.subcommand == "baseline":
+        from lighteval.main_baseline import main as main_baseline
+
+        main_baseline(args)
+
     elif args.subcommand == "tasks":
+        registry = Registry(cache_dir=args.cache_dir, custom_tasks=args.custom_tasks)
         if args.list:
-            Registry(cache_dir="").print_all_tasks()
+            registry.print_all_tasks()
 
         if args.inspect:
             print(f"Loading the tasks dataset to cache folder: {args.cache_dir}")
             print(
                 "All examples will be displayed without few shot, as few shot sample construction requires loading a model and using its tokenizer. "
             )
             # Loading task
-            task_names_list, _ = taskinfo_selector(args.inspect)
-            task_dict = Registry(cache_dir=args.cache_dir).get_task_dict(task_names_list)
+            task_names_list, _ = taskinfo_selector(args.inspect, task_registry=registry)
+            task_dict = registry.get_task_dict(task_names_list)
             for name, task in task_dict.items():
                 print("-" * 10, name, "-" * 10)
                 if args.show_config:
@@ -84,7 +94,6 @@ def cli_evaluate():
                         print("-" * 10, "SAMPLES")
                     print(f"-- sample {ix} --")
                     print(pformat(asdict(sample), indent=1))
-
     else:
         print("You did not provide any argument. Exiting")
 
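For illustration only (not part of the commit): once this wiring is in place, the new subcommand can be exercised programmatically by patching sys.argv before calling cli_evaluate(). The task string and output directory below are made-up placeholders; real task names depend on the local task registry, and actually running this would load the task datasets.

import sys

from lighteval.__main__ import cli_evaluate

# Roughly equivalent to a CLI call such as:
#   <entry point> baseline --tasks "<suite|task|fewshot|truncate>" --output_dir ./baseline_results
sys.argv = [
    "lighteval",
    "baseline",
    "--tasks", "leaderboard|arc:challenge|0|0",   # illustrative task string
    "--output_dir", "./baseline_results",         # illustrative output directory
]
cli_evaluate()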

src/lighteval/main_baseline.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from lighteval.logging.evaluation_tracker import EvaluationTracker
+from lighteval.metrics.utils.metric_utils import MetricCategory
+from lighteval.models.abstract_model import ModelInfo
+from lighteval.tasks.lighteval_task import LightevalTask
+from lighteval.tasks.registry import Registry, taskinfo_selector
+from lighteval.utils.utils import as_list
+
+
+def main(args):
+    """
+    Compute baselines for given tasks.
+
+    It has been tested with generative and accuracy tasks, but may not work correctly for other task types.
+
+    The baseline is computed as follows:
+    - For multiple-choice tasks: It assumes random guessing, so the score is n_correct/number_of_choices.
+    - For other metrics: It assigns a score of 0, which may not be appropriate for all task types.
+
+    Note:
+        This baseline computation may not be suitable for all task types and should be used with caution.
+    """
+    task_registry = Registry(cache_dir=args.cache_dir, custom_tasks=args.custom_tasks)
+    task_names_list, fewshots_dict = taskinfo_selector(args.tasks, task_registry)
+    task_dict = task_registry.get_task_dict(task_names_list)
+
+    evaluation_tracker = EvaluationTracker(
+        output_dir=args.output_dir,
+        save_details=False,
+        push_to_hub=False,
+        push_to_tensorboard=False,
+        public=False,
+        hub_results_org=None,
+    )
+    evaluation_tracker.general_config_logger.log_model_info(
+        ModelInfo(
+            model_name="lighteval/baseline",
+            model_sha=None,
+            model_dtype=None,
+            model_size=None,
+        )
+    )
+    evaluation_tracker.task_config_logger.log(task_dict)
+
+    LightevalTask.load_datasets(list(task_dict.values()), args.dataset_loading_processes)
+
+    for task_name, task in task_dict.items():
+        task_docs = list(task.eval_docs())
+        n_samples = min(args.max_samples, len(task_docs)) if args.max_samples else len(task_docs)
+
+        p_correct_score = [
+            len(as_list(task_doc.gold_index)) / len(task_doc.choices) for task_doc in task_docs[:n_samples]
+        ]
+
+        metric_results = {
+            metric.metric_name: p_correct_score
+            if metric.category
+            in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI, MetricCategory.MULTICHOICE_ONE_TOKEN]
+            else 0
+            for metric in task.metrics
+        }
+
+        for fewshots, _ in fewshots_dict[task_name]:
+            evaluation_tracker.metrics_logger.log(f"{task_name}|{fewshots}", metric_results)
+
+    evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
+    evaluation_tracker.save()
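As a quick, hedged sketch of the baseline rule described in the docstring above: a multiple-choice document gets the random-guessing score len(gold_indices) / len(choices), and everything else falls back to 0. The Doc class below is a stand-in used purely for illustration, not lighteval's actual Doc.

from dataclasses import dataclass
from typing import List, Union


@dataclass
class Doc:  # stand-in type, for illustration only
    choices: List[str]
    gold_index: Union[int, List[int]]


def random_guess_score(doc: Doc) -> float:
    # Mirror as_list(): wrap a single gold index in a list before counting.
    gold = doc.gold_index if isinstance(doc.gold_index, list) else [doc.gold_index]
    return len(gold) / len(doc.choices)


print(random_guess_score(Doc(choices=["A", "B", "C", "D"], gold_index=1)))        # 0.25
print(random_guess_score(Doc(choices=["A", "B", "C", "D"], gold_index=[0, 2])))   # 0.5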

src/lighteval/parsers.py

Lines changed: 39 additions & 0 deletions
@@ -104,6 +104,44 @@ def parser_accelerate(parser=None):
     return parser
 
 
+def parser_baseline(parser=None):
+    if parser is None:
+        parser = argparse.ArgumentParser(
+            description="CLI tool for lighteval, a lightweight framework for LLM evaluation"
+        )
+
+    parser.add_argument(
+        "--custom_tasks",
+        type=str,
+        default=None,
+        help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)",
+    )
+
+    parser.add_argument(
+        "--tasks",
+        type=str,
+        required=True,
+        help="Task to compute the baseline for",
+    )
+    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to evaluate on")
+    parser.add_argument(
+        "--dataset_loading_processes", type=int, default=1, help="Number of processes to use for loading the datasets"
+    )
+
+    parser.add_argument(
+        "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
+    )
+    # Ooutput related
+    parser.add_argument(
+        "--output_dir",
+        required=True,
+        type=str,
+        help="Directory to save the results, fsspec compliant (e.g. s3://bucket/path)",
+    )
+
+    return parser
+
+
 def parser_nanotron(parser=None):
     if parser is None:
         parser = argparse.ArgumentParser(
@@ -142,6 +180,7 @@ def parser_utils_tasks(parser=None):
         default=None,
         help="Id of tasks or path to a text file with a list of tasks (e.g. 'original|mmlu:abstract_algebra|5') for which you want to manually inspect samples.",
     )
+    parser.add_argument("--custom_tasks", type=str, default=None, help="Path to a file with custom tasks")
     parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to display")
     parser.add_argument("--show_config", default=False, action="store_true", help="Will display the full task config")
     parser.add_argument(