Skip to content

Commit 286912e

Browse files
authored
Add BBH (#7)
* manage logprobs for one token, when several tasks have different number of choices * add bbh to test suite
1 parent db44c93 commit 286912e

File tree

8 files changed

+588
-14
lines changed

8 files changed

+588
-14
lines changed

src/lighteval/main_accelerate.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def main(args):
7373
model_config = create_model_config(args=args, accelerator=accelerator)
7474

7575
with htrack_block("Model loading"):
76-
# We need to load the model in the main process first to avoid downloading the model multiple times
7776
with accelerator.main_process_first() if accelerator is not None else nullcontext():
7877
model, model_info = load_model(config=model_config, env_config=env_config)
7978
evaluation_tracker.general_config_logger.log_model_info(model_info)
@@ -84,7 +83,6 @@ def main(args):
8483
task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
8584
task_names_list, custom_tasks=args.custom_tasks
8685
)
87-
# Loading all the dataset in a distributed manner
8886
LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)
8987

9088
evaluation_tracker.task_config_logger.log(task_dict)

src/lighteval/models/base_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import torch
2727
import torch.nn.functional as F
2828
import transformers
29+
from torch.nn.utils.rnn import pad_sequence
2930
from torch.utils.data import DataLoader
3031
from tqdm import tqdm
3132
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -834,9 +835,11 @@ def _loglikelihood_single_token(
834835
# Sync all
835836
# Need reshape before gather
836837
batched_inputs, len_inputs = self.pad_and_gather(prepared_batch.input_ids)
837-
batch_probs = torch.stack(batch_probs)
838+
# We sometimes have different tasks with a different number of choices.
839+
# Padding with -10000000 makes sure that we won't hit index problems later, as every real log prob is larger than the padding value
840+
batch_probs = pad_sequence(batch_probs, batch_first=True, padding_value=-10000000)
838841
batch_probs, len_probs = self.pad_and_gather(batch_probs)
839-
batch_cont_tokens = torch.stack(batch_cont_tokens)
842+
batch_cont_tokens = pad_sequence(batch_cont_tokens, batch_first=True, padding_value=-10000000)
840843
batch_cont_tokens, len_cont = self.pad_and_gather(batch_cont_tokens)
841844

842845
# No reshape

src/lighteval/tasks/lighteval_task.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pathlib import Path
2828
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
2929

30-
from datasets import DownloadMode, load_dataset
30+
from datasets import load_dataset
3131

3232
from lighteval.few_shot_manager import FewShotSampler
3333
from lighteval.logging.hierarchical_logger import hlog, hlog_warn
@@ -108,6 +108,8 @@ class LightevalTaskConfig:
108108

109109
trust_dataset: bool = None
110110

111+
must_remove_duplicate_docs: bool = None
112+
111113
def as_dict(self):
112114
return {
113115
"name": self.name,
@@ -213,6 +215,9 @@ def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str]
213215
self.generation_size = cfg.generation_size
214216
self.stop_sequence = cfg.stop_sequence
215217
self.output_regex = cfg.output_regex
218+
self.must_remove_duplicate_docs = cfg.must_remove_duplicate_docs
219+
if self.must_remove_duplicate_docs is None:
220+
self.must_remove_duplicate_docs = False
216221

217222
# Save options
218223
self.save_queries: bool = False
@@ -318,6 +323,14 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
318323
docs.extend(as_list(cur_docs))
319324
return docs
320325

326+
def remove_duplicate_docs(self, docs: list[Doc]) -> list[Doc]:
    """Return `docs` without duplicates, keeping the first occurrence of each.

    Two docs are considered duplicates when their string representations
    (`str(doc)`) are equal. The relative order of the kept docs is preserved.

    Args:
        docs (list[Doc]): Documents to deduplicate.

    Returns:
        list[Doc]: The deduplicated documents, in their original order.
    """
    seen_keys = set()
    unique_docs = []
    for doc in docs:
        # Serialize once per doc: the string form is used both for the
        # membership test and as the stored key.
        key = str(doc)
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique_docs.append(doc)
    return unique_docs
333+
321334
def fewshot_docs(self) -> list[Doc]:
322335
"""
323336
Returns the few shot documents. If the few shot documents are not
@@ -346,6 +359,8 @@ def eval_docs(self) -> list[Doc]:
346359
"""
347360
if self._docs is None:
348361
self._docs = self._get_docs_from_split(self.evaluation_split)
362+
if self.must_remove_duplicate_docs:
363+
self._docs = self.remove_duplicate_docs(self._docs)
349364
return self._docs
350365

351366
def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str:
@@ -360,12 +375,8 @@ def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str:
360375
Returns:
361376
str: Target of the document, which is the correct answer for a document.
362377
"""
363-
if few_shot:
364-
if formatted_doc.target_for_fewshot_sorting is not None:
365-
return formatted_doc.target_for_fewshot_sorting
366-
367378
# likely we mostly need one example not all
368-
return formatted_doc.get_golds()[0]
379+
return as_list(formatted_doc.get_golds(few_shot=few_shot))[0]
369380

370381
# Requests
371382
def get_request_type(self) -> list[RequestType]:
@@ -572,7 +583,7 @@ def download_dataset_worker(args):
572583
name=dataset_config_name,
573584
data_dir=None,
574585
cache_dir=None,
575-
download_mode=DownloadMode.FORCE_REDOWNLOAD, # None
586+
download_mode=None,
576587
trust_remote_code=trust_dataset,
577588
)
578589
return dataset

src/lighteval/tasks/requests.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222

23-
from dataclasses import dataclass
23+
import json
24+
from dataclasses import asdict, dataclass
2425
from enum import Enum, auto
2526
from typing import NamedTuple, Optional, Union
2627

@@ -175,12 +176,22 @@ class Doc:
175176
num_asked_few_shots: int = -1
176177
num_effective_few_shots: int = -1
177178

178-
def get_golds(self):
179+
def get_golds(self, few_shot: bool = False):
    """Return the gold targets selected by `gold_index`.

    Args:
        few_shot (bool): When True and `target_for_fewshot_sorting` is set,
            golds are read from `target_for_fewshot_sorting` instead of
            `choices`.

    Returns:
        The list of gold targets, or `target_for_fewshot_sorting` itself when
        it is used and is already a single string.
    """
    indices = as_list(self.gold_index)
    use_fewshot_targets = few_shot and self.target_for_fewshot_sorting is not None
    candidates = self.target_for_fewshot_sorting if use_fewshot_targets else self.choices
    if use_fewshot_targets and isinstance(candidates, str):
        # The correct choice is already selected, nothing to index into.
        return candidates
    # Each indexed entry may itself hold several golds: flatten them.
    return [gold for index in indices for gold in as_list(candidates[index])]
194+
195+
def __repr__(self):
    """Return the dataclass fields of this instance serialized as a JSON string."""
    return json.dumps(asdict(self))

src/lighteval/tasks/tasks_prompt_formatting.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import re
2727
import string
2828

29+
import numpy as np
2930
import pycountry
3031

32+
from lighteval.logging.hierarchical_logger import hlog_warn
3133
from lighteval.tasks.requests import Doc
3234
from lighteval.utils import as_list
3335

@@ -137,6 +139,239 @@ def process_path(path: str) -> str:
137139
return queries
138140

139141

142+
def bbh_harness(line, task_name: str = None):
    """Build a Doc from a BBH line whose prompt lists the answer choices as text.

    Fields of `line` that are None or "" are ignored, so the defaults below
    apply to them. Unless `append_choices` is falsy, the choices are sorted,
    shuffled with a fixed seed (42), appended to the query, and the gold index
    is recomputed to point into the shuffled list.

    Args:
        line: Mapping holding at least "input", "target_idx" and "choices".
        task_name (str): Name stored on the returned Doc.

    Returns:
        Doc: The formatted document.
    """
    fields = {key: value for key, value in line.items() if value not in (None, "")}

    prefix = fields.get("task_prefix", "")
    question_marker = fields.get("example_input_prefix", "\nQ: ")
    pieces = [f"{prefix}{question_marker}{fields['input']}"]

    separator = fields.get("choice_prefix", "\n choice: ")
    show_choices = bool(fields.get("append_choices", True))
    # Defaults, used as-is when the choices are not appended to the query.
    gold_position = fields["target_idx"]
    options = fields["choices"]
    if show_choices:
        # A fresh generator with a fixed seed keeps the shuffle reproducible.
        shuffled = list(np.random.RandomState(seed=42).permutation(sorted(options)))
        pieces.append(f"{separator}{separator.join(shuffled)}")
        gold_position = shuffled.index(options[gold_position])
        options = shuffled

    answer_marker = fields.get("example_output_prefix", "\nA: ")
    pieces.append(f"{answer_marker}")

    return Doc(
        task_name=task_name,
        query="".join(pieces),
        choices=options,
        gold_index=gold_position,
        target_for_fewshot_sorting=options,
        instruction=fields.get("task_prefix", None),
    )
172+
173+
174+
def bbh_lighteval(line, task_name: str = None):
    """Build a Doc from a BBH line using lettered choices.

    Fields of `line` that are None are ignored, so the defaults below apply to
    them. Each choice is listed in the query as "<letter>. <choice>", and the
    Doc's choices are the matching leading entries of LETTER_INDICES.

    Args:
        line: Mapping holding at least "input", "choices" and "target_idx".
        task_name (str): Name stored on the returned Doc.

    Returns:
        Doc: The formatted document.
    """
    fields = {name: value for name, value in line.items() if value is not None}

    prompt = fields.get("task_prefix", "")
    prompt += fields.get("example_input_prefix", "\nQuestion: ")
    prompt += fields["input"]
    prompt += fields.get("choice_prefix", "\n Choices: ")
    for letter, option in zip(LETTER_INDICES, fields["choices"]):
        prompt += f"\n{letter}. {option}"
    prompt += fields.get("example_output_prefix", "\nAnswer: ")

    # Two independent slices, so the Doc's two fields never share one object.
    answer_letters = LETTER_INDICES[: len(fields["choices"])]
    gold_position = fields["target_idx"]
    fewshot_letters = LETTER_INDICES[: len(fields["choices"])]

    return Doc(
        task_name=task_name,
        query=prompt,
        choices=answer_letters,
        gold_index=gold_position,
        target_for_fewshot_sorting=fewshot_letters,
        instruction=fields.get("task_prefix", None),
    )
192+
193+
194+
def bbh(line, instruction, choices, task_name: str = None):
    """Build a Doc for a BBH sample from an instruction and a list of choices.

    Args:
        line: Mapping holding "input" (the question) and "target" (the gold answer).
        instruction: Text placed before the question, also stored on the Doc.
        choices: Candidate answers; `line["target"]` must be one of them,
            otherwise `choices.index` raises ValueError.
        task_name (str): Name stored on the returned Doc.

    Returns:
        Doc: The formatted document.
    """
    prompt = f"{instruction}Q: {line['input']}\nA:"
    gold_position = choices.index(line["target"])
    # Few-shot targets carry a leading space.
    spaced_choices = [f" {choice}" for choice in choices]
    return Doc(
        task_name=task_name,
        query=prompt,
        choices=choices,
        gold_index=gold_position,
        target_for_fewshot_sorting=spaced_choices,
        instruction=instruction,
    )
203+
204+
205+
def bbh_boolean_expressions(line, task_name: str = None):
    """Format a boolean-expressions sample via `bbh`, with "False"/"True" as choices."""
    return bbh(
        line,
        instruction="Evaluate the result of a random Boolean expression.\n\n",
        choices=["False", "True"],
        task_name=task_name,
    )
209+
210+
211+
def bbh_causal_judgment(line, task_name: str = None):
    """Format a causal-judgment sample via `bbh`, with "Yes"/"No" as choices."""
    return bbh(
        line,
        instruction="Answer questions about causal attribution.\n\n",
        choices=["Yes", "No"],
        task_name=task_name,
    )
215+
216+
217+
def bbh_date_understanding(line, task_name: str = None):
    """Format a date-understanding sample via `bbh`.

    Choices are the first six entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Infer the date from context.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:6]],
        task_name=task_name,
    )
221+
222+
223+
def bbh_disambiguation_qa(line, task_name: str = None):
    """Format a disambiguation-QA sample via `bbh`.

    Choices are the first three entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Clarify the meaning of sentences with ambiguous pronouns.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:3]],
        task_name=task_name,
    )
227+
228+
229+
def bbh_dyck_languages(line, task_name: str = None):
    """Format a Dyck-languages sample via `bbh`.

    Can only be done in generative: the single choice is the sample's own target.
    """
    return bbh(
        line,
        instruction="Correctly close a Dyck-n word.\n\n",
        choices=[line["target"]],
        task_name=task_name,
    )
233+
234+
235+
def bbh_formal_fallacies(line, task_name: str = None):
    """Format a formal-fallacies sample via `bbh`, with "valid"/"invalid" as choices."""
    return bbh(
        line,
        instruction="Distinguish deductively valid arguments from formal fallacies.\n\n",
        choices=["valid", "invalid"],
        task_name=task_name,
    )
239+
240+
241+
def bbh_geometric_shapes(line, task_name: str = None):
    """Format a geometric-shapes sample via `bbh`.

    Choices are the first eleven entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Name geometric shapes from their SVG paths.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:11]],
        task_name=task_name,
    )
245+
246+
247+
def bbh_hyperbaton(line, task_name: str = None):
    """Format a hyperbaton sample via `bbh`.

    Choices are the first two entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Order adjectives correctly in English sentences.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:2]],
        task_name=task_name,
    )
251+
252+
253+
def bbh_logical_deduction_five_objects(line, task_name: str = None):
    """Format a five-object logical-deduction sample via `bbh`.

    Choices are the first five entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="A logical deduction task which requires deducing the order of a sequence of objects.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:5]],
        task_name=task_name,
    )
257+
258+
259+
def bbh_logical_deduction_seven_objects(line, task_name: str = None):
    """Format a seven-object logical-deduction sample via `bbh`.

    Choices are the first seven entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="A logical deduction task which requires deducing the order of a sequence of objects.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:7]],
        task_name=task_name,
    )
263+
264+
265+
def bbh_logical_deduction_three_objects(line, task_name: str = None):
    """Format a three-object logical-deduction sample via `bbh`.

    Choices are the first three entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="A logical deduction task which requires deducing the order of a sequence of objects.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:3]],
        task_name=task_name,
    )
269+
270+
271+
def bbh_movie_recommendation(line, task_name: str = None):
    """Format a movie-recommendation sample via `bbh`.

    The sample whose target is "Monsters, Inc" is incorrectly formatted in the
    dataset: it is dropped with a warning and an empty list is returned in its
    place. Otherwise the choices are the first six entries of LETTER_INDICES,
    each wrapped in parentheses.

    Args:
        line: Mapping holding at least "input" and "target".
        task_name (str): Name stored on the returned Doc.

    Returns:
        The Doc built by `bbh`, or `[]` for the dropped sample.
    """
    if line["target"] == "Monsters, Inc":  # this line is not correctly formatted
        # The task name in the message is spelled correctly so that log searches find it.
        hlog_warn("One sample removed from task bbh:movie_recommendation because its line is incorrectly formatted.")
        return []
    instruction = "Recommend movies similar to the given list of movies.\n\n"
    choices = [f"({c})" for c in LETTER_INDICES[:6]]
    return bbh(line, instruction, choices, task_name)
278+
279+
280+
def bbh_multistep_arithmetic_two(line, task_name: str = None):
    """Format a multi-step arithmetic sample via `bbh`.

    Can only be done in generative: the single choice is the sample's own target.
    """
    return bbh(
        line,
        instruction="Solve multi-step arithmetic problems.\n\n",
        choices=[line["target"]],
        task_name=task_name,
    )
284+
285+
286+
def bbh_navigate(line, task_name: str = None):
    """Format a navigation sample via `bbh`, with "Yes"/"No" as choices."""
    return bbh(
        line,
        instruction="Given a series of navigation instructions, determine whether one would end up back at the starting point.\n\n",
        choices=["Yes", "No"],
        task_name=task_name,
    )
292+
293+
294+
def bbh_object_counting(line, task_name: str = None):
    """Format an object-counting sample via `bbh`.

    Choices are the integers 1 through 18, as strings.
    """
    return bbh(
        line,
        instruction="Questions that involve enumerating objects and asking the model to count them.\n\n",
        choices=[str(count) for count in range(1, 19)],
        task_name=task_name,
    )
298+
299+
300+
def bbh_penguins_in_a_table(line, task_name: str = None):
    """Format a penguins-in-a-table sample via `bbh`.

    Choices are the first five entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Answer questions about a table of penguins and their attributes.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:5]],
        task_name=task_name,
    )
304+
305+
306+
def bbh_reasoning_about_colored_objects(line, task_name: str = None):
    """Format a colored-objects reasoning sample via `bbh`.

    Choices are the first eighteen entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Answer extremely simple questions about the colors of objects on a surface.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:18]],
        task_name=task_name,
    )
310+
311+
312+
def bbh_ruin_names(line, task_name: str = None):
    """Format a ruin-names sample via `bbh`.

    Samples whose target is one of two known incorrectly formatted lines are
    dropped with a warning, and an empty list is returned in their place.
    Otherwise the choices are the first six entries of LETTER_INDICES, each
    wrapped in parentheses.
    """
    # Targets of the lines that are not correctly formatted.
    malformed_targets = ("dearth, wind, & fire", "rita, sue and bob poo")
    if line["target"] in malformed_targets:
        hlog_warn("One sample removed from task bbh:ruin_names because its line is incorrectly formatted.")
        return []
    return bbh(
        line,
        instruction="Select the humorous edit that 'ruins' the input movie or musical artist name.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:6]],
        task_name=task_name,
    )
319+
320+
321+
def bbh_salient_translation_error_detection(line, task_name: str = None):
    """Format a translation-error-detection sample via `bbh`.

    Choices are the first six entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Detect the type of error in an English translation of a German source sentence.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:6]],
        task_name=task_name,
    )
325+
326+
327+
def bbh_snarks(line, task_name: str = None):
    """Format a snarks (sarcasm detection) sample via `bbh`.

    Choices are the first two entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction='Determine which of two sentences is sarcastic.\n\nAccording to Cambridge University Dictionary, sarcasm is "the use of remarks that clearly mean the opposite of what they say, made in order to hurt someone\'s feelings or to criticize something in a humorous way." Sarcastic sentences often contain satirical or ironic utterances, hyperboles, ambivalent or witty remarks.\n\n',
        choices=[f"({letter})" for letter in LETTER_INDICES[:2]],
        task_name=task_name,
    )
331+
332+
333+
def bbh_sports_understanding(line, task_name: str = None):
    """Format a sports-understanding sample via `bbh`, with "yes"/"no" as choices."""
    return bbh(
        line,
        instruction="Determine whether an artificially constructed sentence relating to sports is plausible or not.\n\n",
        choices=["yes", "no"],
        task_name=task_name,
    )
337+
338+
339+
def bbh_temporal_sequences(line, task_name: str = None):
    """Format a temporal-sequences sample via `bbh`.

    Choices are the first four entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="Task description: Answer questions about which times certain events could have occurred.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:4]],
        task_name=task_name,
    )
343+
344+
345+
def bbh_tracking_shuffled_objects_five_objects(line, task_name: str = None):
    """Format a five-object shuffled-objects tracking sample via `bbh`.

    Choices are the first five entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:5]],
        task_name=task_name,
    )
349+
350+
351+
def bbh_tracking_shuffled_objects_seven_objects(line, task_name: str = None):
    """Format a seven-object shuffled-objects tracking sample via `bbh`.

    Choices are the first seven entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:7]],
        task_name=task_name,
    )
355+
356+
357+
def bbh_tracking_shuffled_objects_three_objects(line, task_name: str = None):
    """Format a three-object shuffled-objects tracking sample via `bbh`.

    Choices are the first three entries of LETTER_INDICES, each wrapped in parentheses.
    """
    return bbh(
        line,
        instruction="A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n",
        choices=[f"({letter})" for letter in LETTER_INDICES[:3]],
        task_name=task_name,
    )
361+
362+
363+
def bbh_web_of_lies(line, task_name: str = None):
    """Format a web-of-lies sample via `bbh`, with "Yes"/"No" as choices."""
    return bbh(
        line,
        instruction="Evaluate a random boolean function expressed as a word problem.\n\n",
        choices=["Yes", "No"],
        task_name=task_name,
    )
367+
368+
369+
def bbh_word_sorting(line, task_name: str = None):
    """Format a word-sorting sample via `bbh`.

    Can only be done in generative: the single choice is the sample's own target.
    """
    return bbh(
        line,
        instruction="Sort a list of words.\n\n",
        choices=[line["target"]],
        task_name=task_name,
    )
373+
374+
140375
def bbq(line, task_name: str = None): # HELM
141376
query = f"The following are multiple choice questions (with answers).\nPassage: {line['context']}\nQuestion: {line['question']}"
142377
query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])

0 commit comments

Comments
 (0)