Commit 3412567

Adding TinyBench (#104)
* Edited mechanism for corpus aggregations using dict
* Very important bug fix on aggregation call!
1 parent 1ec7222 commit 3412567

4 files changed: +316, -7 lines
Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team & Felipe Maia Polo

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# ruff: noqa: F405, F403, F401
"""
See https://github.com/felipemaiapolo/tinyBenchmarks/ for the original code.

Test with `python run_evals_accelerate.py --model_args "pretrained=EleutherAI/pythia-70m" --tasks "extended|tiny:winogrande|0|0,extended|tiny:gsm8k|0|0,extended|tiny:hellaswag|0|0,extended|tiny:arc|0|0,extended|tiny:truthfulqa|0|0" --extended_tasks extended_tasks --output_dir "./evals"`
"""
import os
import pickle

import numpy as np
import requests
from aenum import extend_enum
from scipy.optimize import minimize

from lighteval.metrics import Metrics
from lighteval.metrics.metrics import CorpusLevelMetricGrouping
from lighteval.metrics.metrics_sample import ExactMatches, LoglikelihoodAcc
from lighteval.metrics.normalizations import gsm8k_normalizer
from lighteval.metrics.utils import MetricCategory, MetricUseCase
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


# Utility functions
def sigmoid(z):
    return 1 / (1 + np.exp(-z))


def item_curve(theta, a, b):
    z = np.clip(a * theta - b, -30, 30).sum(axis=1)
    return sigmoid(z)


def fit_theta(responses_test, seen_items, A, B, theta_init=None, eps=1e-10, optimizer="BFGS"):
    D = A.shape[1]

    # Define the negative log likelihood function
    def neg_log_like(x):
        P = item_curve(x.reshape(1, D, 1), A[:, :, seen_items], B[:, :, seen_items]).squeeze()
        log_likelihood = np.sum(
            responses_test[seen_items] * np.log(P + eps) + (1 - responses_test[seen_items]) * np.log(1 - P + eps)
        )
        return -log_likelihood

    # Use the minimize function to find the ability parameters that minimize the negative log likelihood
    optimal_theta = minimize(neg_log_like, np.zeros(D), method=optimizer).x[None, :, None]
    return optimal_theta

# Evaluation function
class TinyCorpusAggregator:
    LEADERBOARD_SCENARIOS = ["truthfulqa", "gsm8k", "winogrande", "arc", "hellaswag"]
    BENCHS = ["lb", "mmlu"]
    METRICS = ["irt", "pirt", "gpirt"]
    # Not included yet:
    # - helm_lite (not avail on datasets)
    # - alpaca (needs to be added to lighteval first)

    def __init__(self, task: str):
        self.number_of_examples = 100
        if task not in self.LEADERBOARD_SCENARIOS + self.BENCHS:
            raise ValueError(f"Bench name must be one of {','.join(self.LEADERBOARD_SCENARIOS + self.BENCHS)}.")
        self.task = task
        self.scenario = "lb" if task in self.LEADERBOARD_SCENARIOS else task
        self.download()
        self.estimates = None
        self.num_samples = 0

    def download(self):
        # Downloading files
        if not os.path.isfile("extended_tasks/tiny_benchmarks/tinyBenchmarks.pkl"):
            url = "https://raw.githubusercontent.com/felipemaiapolo/tinyBenchmarks/main/tinyBenchmarks/tinyBenchmarks.pkl"
            response = requests.get(url)
            if response.status_code == 200:
                # Write the content to a file
                with open("extended_tasks/tiny_benchmarks/tinyBenchmarks.pkl", "wb") as file:
                    file.write(response.content)

    def compute(self, **args):
        if self.task == "gsm8k":
            res = ExactMatches(
                strip_strings=True, normalize_pred=gsm8k_normalizer, normalize_gold=gsm8k_normalizer
            ).compute(**args)
            return {m: res for m in self.METRICS}
        else:
            res = LoglikelihoodAcc().compute(**args)
            return {m: res for m in self.METRICS}

    def aggregate(self, y_input):
        if len(y_input) == self.num_samples and self.estimates is not None:
            return self.estimates[self.task]

        # We load the weights for the relevant examples
        with open("extended_tasks/tiny_benchmarks/tinyBenchmarks.pkl", "rb") as handle:
            tinyBenchmarks = pickle.load(handle)

        seen_examples = tinyBenchmarks[self.scenario]["seen_examples"]
        examples_weights = tinyBenchmarks[self.scenario]["examples_weights"]
        irt_parameters = tinyBenchmarks[self.scenario]["irt_parameters"]
        A, B = irt_parameters["A"], irt_parameters["B"]
        optimal_lambdas = tinyBenchmarks[self.scenario]["optimal_lambdas"]
        scenarios_position = tinyBenchmarks[self.scenario]["scenarios_position"]
        subscenarios_position = tinyBenchmarks[self.scenario]["subscenarios_position"]

        N = np.max([np.max(x) for x in scenarios_position.values()]) + 1
        balance_weights = np.ones(N)
        for scenario in scenarios_position.keys():
            N_sce = len(scenarios_position[scenario])
            n_sub = len(subscenarios_position[scenario])
            for sub in subscenarios_position[scenario].keys():
                n_i = len(subscenarios_position[scenario][sub])
                balance_weights[subscenarios_position[scenario][sub]] = N_sce / (n_sub * n_i)

        # In case we use the big IRT model to estimate the performance of individual scenarios
        if self.task not in self.BENCHS:
            scenarios = [self.task]
            ind_scenario = (
                self.number_of_examples * ([i for i, s in enumerate(scenarios_position.keys()) if s == self.task][0])
            )
            seen_examples = seen_examples[ind_scenario : ind_scenario + self.number_of_examples]
        else:
            scenarios = list(scenarios_position.keys())

        # Creating vector y and estimating theta
        y = np.zeros(N)
        for i, j in enumerate(seen_examples):
            y[j] = y_input[i]

        # Getting estimates
        theta = fit_theta(y, seen_examples, A, B)
        estimates = {}
        unseen_examples = [i for i in range(N) if i not in seen_examples]

        for scenario in scenarios:
            N_sce = len(scenarios_position[scenario])
            seen_examples_sce = [s for s in seen_examples if s in scenarios_position[scenario]]
            unseen_examples_sce = [s for s in unseen_examples if s in scenarios_position[scenario]]

            data_part_IRTp = ((balance_weights * y)[seen_examples_sce]).mean()
            irt_part = (balance_weights * item_curve(theta.reshape(1, A.shape[1], 1), A, B))[
                0, [unseen_examples_sce]
            ].mean()
            IRTp_lambd = self.number_of_examples / N_sce
            IRT = (examples_weights[scenario] * y[seen_examples_sce]).sum()
            IRTp = IRTp_lambd * data_part_IRTp + (1 - IRTp_lambd) * irt_part
            IRTpp = optimal_lambdas[scenario] * IRT + (1 - optimal_lambdas[scenario]) * IRTp

            estimates[scenario] = {}
            estimates[scenario]["irt"] = IRT
            estimates[scenario]["pirt"] = IRTp
            estimates[scenario]["gpirt"] = IRTpp

        self.num_samples = len(y_input)
        self.estimates = estimates

        return estimates[self.task]

# TASK CREATION
task_params = [
    {
        "name": "winogrande",
        "dataset": "tinyBenchmarks/tinyWinogrande",
        "subset": "winogrande_xl",
        "prompt": "winogrande",
        "splits": ["train", "validation", "test"],
        "evaluation_split": ["validation"],
    },
    {
        "name": "arc",
        "dataset": "tinyBenchmarks/tinyAI2_arc",
        "subset": "ARC-Challenge",
        "prompt": "arc",
        "splits": ["train", "validation", "test"],
        "evaluation_split": ["validation"],
    },
    {
        "name": "hellaswag",
        "dataset": "tinyBenchmarks/tinyHellaswag",
        "subset": "default",
        "prompt": "hellaswag_harness",
        "splits": ["train", "validation", "test"],
        "evaluation_split": ["validation"],
    },
    {
        "name": "mmlu",
        "dataset": "tinyBenchmarks/tinyMMLU",
        "subset": "all",
        "prompt": "mmlu_harness",
        "splits": ["validation", "dev", "test"],
        "evaluation_split": ["test"],
    },
    {
        "name": "truthfulqa",
        "dataset": "tinyBenchmarks/tinyTruthfulQA",
        "subset": "multiple_choice",
        "prompt": "truthful_qa_multiple_choice",
        "splits": ["validation"],
        "evaluation_split": ["validation"],
    },
    {
        "name": "gsm8k",
        "dataset": "tinyBenchmarks/tinyGSM8k",
        "subset": "main",
        "prompt": "gsm8k",
        "splits": ["train", "test"],
        "evaluation_split": ["test"],
    },
    # {
    #     "name": "alpacaeval",
    #     "dataset": "tinyBenchmarks/tinyAlpacaEval",
    #     "subset": "default"
    # },
]

_TASKS = []
for task in task_params:
    name = task["name"]
    generation_size = None
    stop_sequence = None
    if name == "gsm8k":
        generation_size = 256
        stop_sequence = ["Question:", "Question"]
    task = LightevalTaskConfig(
        name=f"tiny:{name}",
        prompt_function=task["prompt"],
        suite=["extended"],
        hf_repo=task["dataset"],
        hf_subset=task["subset"],
        hf_avail_splits=task["splits"],
        evaluation_splits=task["evaluation_split"],
        few_shots_split=None,
        few_shots_select="random_sampling",
        metric=[f"tinybench_metric_{name}"],
        generation_size=generation_size,
        stop_sequence=stop_sequence,
    )
    _TASKS.append(task)

# CUSTOM METRIC
for task_param in task_params:
    name = task_param["name"]
    if name == "gsm8k":
        category = MetricCategory.GENERATIVE
        use_case = MetricUseCase.MATH
    else:
        category = MetricCategory.MULTICHOICE
        use_case = MetricUseCase.ACCURACY

    extend_enum(
        Metrics,
        f"tinybench_metric_{name}",
        CorpusLevelMetricGrouping(
            metric=TinyCorpusAggregator.METRICS,
            higher_is_better={m: True for m in TinyCorpusAggregator.METRICS},
            sample_level_fn=TinyCorpusAggregator(name).compute,
            category=category,
            use_case=use_case,
            corpus_level_fn=TinyCorpusAggregator(name).aggregate,
        ),
    )

# MODULE LOGIC
# You should not need to touch this
# Convert to dict for lighteval
TASKS_TABLE = [task.as_dict() for task in _TASKS]

if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
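
For orientation, here is a minimal usage sketch of the aggregation contract above. It is not part of the commit: the import path `tiny_benchmarks` is hypothetical, and it assumes the `extended_tasks/tiny_benchmarks/` directory exists relative to the working directory so the constructor can download and read `tinyBenchmarks.pkl`.

import numpy as np

from tiny_benchmarks import TinyCorpusAggregator, fit_theta, item_curve  # hypothetical module name

# Stand-in per-sample correctness scores for the 100 tinyWinogrande examples.
rng = np.random.default_rng(0)
y_input = rng.binomial(1, 0.65, size=100).astype(float)

# One corpus-level call returns all three IRT-based estimators at once.
agg = TinyCorpusAggregator("winogrande")
estimates = agg.aggregate(y_input)  # {"irt": ..., "pirt": ..., "gpirt": ...}
print(estimates)

# The underlying fit can also be exercised directly on toy IRT parameters.
D, n_items = 2, 10
A = rng.uniform(0.5, 1.5, size=(1, D, n_items))   # toy discrimination parameters
B = rng.uniform(-1.0, 1.0, size=(1, D, n_items))  # toy difficulty parameters
responses = rng.binomial(1, 0.7, size=n_items).astype(float)
theta = fit_theta(responses, list(range(n_items)), A, B)  # ability estimate, shape (1, D, 1)
print(item_curve(theta, A, B))  # predicted P(correct) for each toy item

Calling `aggregate` again with the same number of samples returns the cached `estimates`, which matters because lighteval invokes the corpus-level function once per registered sub-metric name (irt, pirt, gpirt).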

src/lighteval/logging/info_loggers.py

Lines changed: 12 additions & 4 deletions
@@ -458,21 +458,29 @@ def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int =
             # fix the fact that we need the task_dict
             task = task_dict[cur_task_name]

+            skip_metric = []
             for metric_name, metric_values in metrics.items():
+                if metric_name in skip_metric:
+                    # The metric is in a subset which has already been computed and saved
+                    continue
+
                 try:
                     metric_result = task.aggregation()[metric_name](metric_values)
                 except OverflowError:
                     hlog_warn(f"{task_name}, {metric_name} got an OVERFLOW ERROR when aggregating.")
                     metric_result = float("nan")

-                if isinstance(metric_result, dict):  # in which cases do we get a dict here?
+                if isinstance(metric_result, dict):  # For some corpus level grouping metrics
                     self.metric_aggregated[task_name].update(metric_result)
+                    skip_metric.extend(list(metric_result.keys()))  # no need to recompute them later
                 else:
                     self.metric_aggregated[task_name][metric_name] = metric_result

-                aggregation = task.aggregation()[metric_name]
-
-                stderr = get_stderr_function(aggregation=aggregation, number_experiments=1000)
+                if isinstance(metric_result, dict):
+                    stderr = None  # We skip stderr for some corpus metrics that return dicts
+                else:
+                    aggregation = task.aggregation()[metric_name]
+                    stderr = get_stderr_function(aggregation=aggregation, number_experiments=1000)
                 if stderr is not None and len(metric_values) > 1:
                     try:
                         self.metric_aggregated[task_name][f"{metric_name}_stderr"] = stderr(metric_values)

src/lighteval/metrics/metrics.py

Lines changed: 9 additions & 2 deletions
@@ -522,13 +522,20 @@ def higher_is_better():
         return res

     @staticmethod
-    def corpus_level_fns() -> dict[str, callable]:
+    def corpus_level_fns(metrics: list[str]) -> dict[str, callable]:
         res = {}
         for metric in Metrics:
+            if metric.name not in metrics:
+                continue
             if metric.value.category == MetricCategory.IGNORED:
                 continue
             if isinstance(metric.value, MetricGrouping):
-                res.update(metric.value.corpus_level_fn)
+                if isinstance(metric.value.corpus_level_fn, dict):
+                    res.update(metric.value.corpus_level_fn)
+                else:
+                    # Must make sure there is a caching implementation here
+                    for m in metric.value.metric:
+                        res[m] = metric.value.corpus_level_fn
             else:
                 res[metric.value.metric] = metric.value.corpus_level_fn
         return res
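
The new `metrics` parameter narrows the lookup to the aggregations a task actually declares, and the non-dict branch registers the same grouped callable once per sub-metric name. A rough standalone sketch of the resulting mapping (hypothetical names and a fake registry, not the real `Metrics` enum):

from typing import Callable

def tiny_aggregate(scores):  # shared grouped corpus-level callable (stand-in)
    return {"irt": 0.0, "pirt": 0.0, "gpirt": 0.0}

registry = {
    "tinybench_metric_winogrande": {"metric": ["irt", "pirt", "gpirt"], "corpus_level_fn": tiny_aggregate},
    "exact_match": {"metric": "em", "corpus_level_fn": lambda scores: sum(scores) / len(scores)},
}

def corpus_level_fns(metrics: list[str]) -> dict[str, Callable]:
    res = {}
    for name, value in registry.items():
        if name not in metrics:
            continue  # only keep aggregations the task actually uses
        if isinstance(value["metric"], list):  # grouped metric
            for m in value["metric"]:
                res[m] = value["corpus_level_fn"]  # same callable under each sub-metric name
        else:
            res[value["metric"]] = value["corpus_level_fn"]
    return res

print(corpus_level_fns(["tinybench_metric_winogrande"]))  # keys: irt, pirt, gpirt

This is what `aggregation()` in lighteval_task.py (below) now exercises by passing `self.metrics`.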

src/lighteval/tasks/lighteval_task.py

Lines changed: 2 additions & 1 deletion
@@ -321,6 +321,7 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
         """
         if self.dataset is None:
             self.dataset = download_dataset_worker((self.dataset_path, self.dataset_config_name, self.trust_dataset))
+        splits = as_list(splits)

         docs = []
         for split in splits:

@@ -553,7 +554,7 @@ def aggregation(self):
         Return a dict with metric name and its aggregation function for all
         metrics
         """
-        return Metrics.corpus_level_fns()
+        return Metrics.corpus_level_fns(self.metrics)

     @staticmethod
     def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int = 1) -> None:
