Commit 1b703b4
Authored by hynky1999 (Hynek Kydlicek), with NathanHB (Nathan Habib) and clefourrier (Clémentine Fourrier)
Multilingual Hellaswag tasks (#332)
* add multilingual dynamic generative metrics
* draft
* finish multichoice config
* update tokenizers + install nltk reqs
* use punkt tab
* Update src/lighteval/utils/imports.py (Co-authored-by: Nathan Habib <[email protected]>)
* Update src/lighteval/metrics/normalizations.py (Co-authored-by: Nathan Habib <[email protected]>)
* fix imports
* remove unused import
* finish implementation of templates + move stuff around
* resolve nits
* when in Rome, do as the Romans do (handle error messages the same way)
* fix utils
* nicer tests + fix them
* nicer todo
* add nice docstrings 📃
* add even more docstrings
* nit
* fix test
* add multilingual to dev group
* merge nli, add languages to literals
* translation literals
* add nli
* add copa tasks + fix translation literals
* add hellaswag tasks
* remove custom Telugu hellaswag
* remove Hindi hellaswag
* add rcb + Chinese nli
* Update src/lighteval/tasks/multilingual/tasks.py — seven review suggestions (Co-authored-by: Clémentine Fourrier <[email protected]>)
* add two new tasks + docs
* add nice docs
* update hellaswag with docs
* move hellaswag to lighteval suite
* Update src/lighteval/tasks/multilingual/tasks.py (Co-authored-by: Clémentine Fourrier <[email protected]>)
* enable returning None from templates + better typing
* change unofficial hellaswag names to have community_prefix + unify hellaswag preprocessing
* let strip be optional in hellaswag

Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Hynek Kydlicek <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 6f30384 commit 1b703b4

File tree: 11 files changed (+522, −30 lines)


src/lighteval/tasks/default_prompts.py

Lines changed: 19 additions & 11 deletions
@@ -755,21 +755,29 @@ def headqa(line, task_name: str = None):
 )
 
 
-def hellaswag_harness(line, task_name: str = None):
-    def preprocess(text):
-        """Comes from AiHarness"""
-        # text = text.strip()
-        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-        text = text.replace(" [title]", ". ")
-        text = re.sub("\\[.*?\\]", "", text)
-        text = text.replace("  ", " ")
-        return text
+def hellaswag_preprocess(
+    text: str, wikihow_artifacts: list[str] = [" [title]"], truncate_dots: bool = False, strip_text: bool = False
+):
+    """Comes from AiHarness"""
+    # text = text.strip()
+    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
+    for dot_repl in wikihow_artifacts:
+        text = text.replace(dot_repl, ". ")
+    text = re.sub("\\[.*?\\]", "", text)
+    text = text.replace("  ", " ")
+    if truncate_dots:
+        text = text.replace(r"\.+", r"\.")
+    if strip_text:
+        text = text.strip()
+    return text
 
+
+def hellaswag_harness(line, task_name: str = None):
     ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
     return Doc(
         task_name=task_name,
-        query=preprocess(line["activity_label"] + ": " + ctx),
-        choices=[preprocess(ending) for ending in line["endings"]],
+        query=hellaswag_preprocess(line["activity_label"] + ": " + ctx),
+        choices=[hellaswag_preprocess(ending) for ending in line["endings"]],
         gold_index=int(line["label"]) if line["label"] != "" else -1,  # -1 for test
         # "metric": "choices_loglikelihood",
     )
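
For reference, a quick usage sketch of the merged helper; the sample string is illustrative, not taken from the dataset:

    from lighteval.tasks.default_prompts import hellaswag_preprocess

    # " [title]" is replaced by ". ", leftover "[...]" bracket artifacts are
    # removed by the re.sub call, double spaces are collapsed, and
    # strip_text=True trims surrounding whitespace.
    raw = "Kite flying [title] A man runs across a field holding a kite."
    print(hellaswag_preprocess(raw, strip_text=True))
    # -> "Kite flying. A man runs across a field holding a kite."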

src/lighteval/tasks/lighteval_task.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class LightevalTaskConfig:
     """
 
     name: str
-    prompt_function: Callable[[dict, str], Doc]
+    prompt_function: Callable[[dict, str], Doc | None]
     hf_repo: str
     hf_subset: str
     metric: ListLike[Metric | Metrics]
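
The widened return type means a prompt function may now return None to drop a row it cannot convert, rather than raising. A minimal sketch, assuming Doc is imported from lighteval.tasks.requests and using hypothetical dataset fields ("question", "choices", "label"):

    from lighteval.tasks.requests import Doc

    def skip_unlabeled(line: dict, task_name: str = None) -> Doc | None:
        # Rows without a gold label are dropped instead of raising.
        if line.get("label", "") == "":
            return None
        return Doc(
            task_name=task_name,
            query=line["question"],
            choices=line["choices"],
            gold_index=int(line["label"]),
        )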

src/lighteval/tasks/multilingual/tasks.py

Lines changed: 142 additions & 0 deletions
@@ -27,6 +27,7 @@
 from lighteval.metrics.normalizations import LogProbTokenNorm
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.templates.copa import get_copa_prompt_function
+from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function
 from lighteval.tasks.templates.nli import get_nli_prompt_function
 from lighteval.tasks.templates.utils.formulation import (
     CFFormulation,
@@ -386,6 +387,9 @@
         ),
         hf_repo="ai4bharat/IndicCOPA",
         hf_subset=f"translation-{standardize_tag(language.value)}",
+        # Since we use trust_dataset, we have to be careful about what is inside the dataset
+        # script. We thus lock the revision to ensure that the script doesn't change
+        hf_revision="d356ef19a4eb287e88a51d07a56b73ba88c7f188",
         evaluation_splits=["test"],
         metric=[
             loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
@@ -443,6 +447,141 @@
 ]
 
 
+# ------------------------------- Hellaswag Tasks ------------------------------- #
+# Hellaswag is a commonsense reasoning task that requires models to complete a given scenario
+# with the most plausible ending. It tests the model's ability to understand and reason about
+# everyday situations and human behavior.
+
+# MLMM-Hellaswag: Multilingual adaptation of Hellaswag
+# Paper: https://arxiv.org/abs/2306.07610
+# This is a multilingual version of Hellaswag, part of the MLMM (Massive Language Model Meta-Evaluation) benchmark.
+# It evaluates commonsense reasoning abilities across multiple languages.
+mlmm_hellaswag_tasks = [
+    LightevalTaskConfig(
+        name=f"hellaswag_{lang.value}_{formulation.name.lower()}",
+        suite=["lighteval"],
+        prompt_function=get_hellaswag_prompt_function(
+            language=lang,
+            adapter=lambda line: {
+                # We don't use activity_label as it is not available
+                "ctx_a": line["ctx_a"],
+                "ctx_b": line["ctx_b"],
+                "continuations": line["endings"],
+                "gold_idx": int(line["label"]),
+            },
+            formulation=formulation,
+        ),
+        hf_repo="jon-tow/okapi_hellaswag",
+        hf_subset=standardize_tag(lang.value),
+        # Since we use trust_dataset, we have to be careful about what is inside the dataset
+        # script. We thus lock the revision to ensure that the script doesn't change
+        hf_revision="96ed8e0dfc6172dad1d3df338d7b8ba6c1ff9d83",
+        evaluation_splits=["validation"],
+        metric=[
+            loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+        ],
+        trust_dataset=True,
+    )
+    for lang in [
+        Language.ARABIC,
+        Language.BENGALI,
+        Language.CATALAN,
+        Language.DANISH,
+        Language.GERMAN,
+        Language.SPANISH,
+        Language.BASQUE,
+        Language.FRENCH,
+        Language.GUJARATI,
+        Language.HINDI,
+        Language.CROATIAN,
+        Language.HUNGARIAN,
+        Language.ARMENIAN,
+        Language.INDONESIAN,
+        Language.ICELANDIC,
+        Language.ITALIAN,
+        Language.KANNADA,
+        Language.MALAYALAM,
+        Language.MARATHI,
+        Language.NORWEGIAN,
+        Language.NEPALI,
+        Language.DUTCH,
+        Language.PORTUGUESE,
+        Language.ROMANIAN,
+        Language.RUSSIAN,
+        Language.SLOVAK,
+        Language.SERBIAN,
+        Language.SWEDISH,
+        Language.TAMIL,
+        Language.TELUGU,
+        Language.UKRAINIAN,
+        Language.VIETNAMESE,
+        Language.CHINESE,
+    ]
+    for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+]
+
+# Hellaswag Turkish
+# This is a Turkish adaptation of the Hellaswag task.
+# While there's no specific paper for this version, it has been found to work well for evaluating
+# Turkish language models on commonsense reasoning tasks.
+
+# We don't handle these in a single task, as there are quite a lot of differences
+# (dataset/subset, dot replacement, etc.) which would make it hard to read
+hellaswag_tur_tasks = [
+    LightevalTaskConfig(
+        name=f"community_hellaswag_{Language.TURKISH.value}_{formulation.name.lower()}",
+        suite=["lighteval"],
+        prompt_function=get_hellaswag_prompt_function(
+            language=Language.TURKISH,
+            adapter=lambda line: {
+                "ctx_a": line["ctx_a"],
+                "ctx_b": line["ctx_b"],
+                "continuations": line["endings"],
+                "gold_idx": int(line["label"]),
+            },
+            formulation=formulation,
+            # https://github.com/malhajar17/lm-evaluation-harness_turkish/blob/main/lm_eval/tasks/hellaswag_tr-v0.2/utils.py
+            wikihow_artifacts=[" [title]", " [başlık]", " [adım]", " [header]"],
+        ),
+        hf_repo="malhajar/hellaswag_tr-v0.2",
+        hf_subset="default",
+        evaluation_splits=["validation"],
+        metric=[
+            loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+        ],
+    )
+    for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+]
+
+# Hellaswag Thai
+# This is a Thai adaptation of the Hellaswag task.
+# Similar to the Turkish version, there's no specific paper, but it has been found to be effective
+# for evaluating Thai language models on commonsense reasoning tasks.
+hellaswag_tha_tasks = [
+    LightevalTaskConfig(
+        name=f"community_hellaswag_{Language.THAI.value}_{formulation.name.lower()}",
+        suite=["lighteval"],
+        prompt_function=get_hellaswag_prompt_function(
+            language=Language.THAI,
+            adapter=lambda line: {
+                "ctx_a": line["ctx_a"],
+                "ctx_b": line["ctx_b"],
+                "continuations": line["endings"],
+                "gold_idx": int(line["label"]),
+            },
+            formulation=formulation,
+        ),
+        hf_repo="HuggingFaceFW-Dev/hellaswag_thai",
+        hf_subset="default",
+        evaluation_splits=["validation"],
+        few_shots_split="train",
+        metric=[
+            loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+        ],
+    )
+    for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+]
+
 TASKS_TABLE = [
     *xnli_tasks,
     *xnli2_tasks,
@@ -454,4 +593,7 @@
     *xcopa_tasks,
     *copa_indic_tasks,
     *parus_tasks,
+    *mlmm_hellaswag_tasks,
+    *hellaswag_tur_tasks,
+    *hellaswag_tha_tasks,
 ]
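
Each task list above is the cross product of a language list and the three formulations, so the MLMM block alone expands to 33 × 3 = 99 configs (plus 3 Turkish and 3 Thai community tasks). A tiny sketch of the naming scheme, with stand-in language codes rather than the real Language enum values:

    mlmm_names = [
        f"hellaswag_{lang}_{formulation}"
        for lang in ("ar", "bn", "zh")  # 33 codes in the real list
        for formulation in ("mcf", "cf", "hybrid")  # MCF, CF, Hybrid
    ]
    assert len(mlmm_names) == 3 * 3  # the full list yields 33 * 3 = 99 tasks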

src/lighteval/tasks/templates/continuation.py

Lines changed: 14 additions & 4 deletions
@@ -84,7 +84,7 @@ class ContinuationDictAdapter(TypedDict):
 
 def get_continuation_prompt_function(
     language: Language,
-    adapter: Callable[[dict], ContinuationInput] | ContinuationDictAdapter,
+    adapter: Callable[[dict], ContinuationInput | None] | ContinuationDictAdapter,
     formulation: Formulation = MCFFormulation(),
 ):
     """
@@ -121,11 +121,13 @@ def get_continuation_prompt_function(
     Returns:
         Callable: A function that generates Continuation prompt based on the given parameters.
     """
-    adapter_fn: Callable[[dict], ContinuationInput] = create_adapter_from_dict(adapter)  # type: ignore
+    adapter_fn = create_adapter_from_dict(adapter)
     translation_literals = TRANSLATION_LITERALS[language]
 
     def prepare_prompt(line: dict):
         cont_input = adapter_fn(line)
+        if cont_input is None:
+            return None
 
         instruction_val = cont_input.get("instruction")
         instruction = f"{instruction_val}\n" if instruction_val else ""
@@ -140,7 +142,11 @@ def prepare_prompt(line: dict):
         return cont_input, instruction, context, continuations
 
     def prompt_fn_cf(line, task_name: str):
-        cont_input, instruction, context, continuations = prepare_prompt(line)
+        prepared_prompt = prepare_prompt(line)
+        if prepared_prompt is None:
+            return None
+
+        cont_input, instruction, context, continuations = prepared_prompt
 
         context_follows_sentence_space = punctuation_ends_sentence(context, translation_literals)
         answers = build_answers(continuations, formulation, translation_literals, context_follows_sentence_space)
@@ -160,7 +166,11 @@ def prompt_fn_cf(line, task_name: str):
         )
 
     def prompt_fn_mcf(line, task_name: str):
-        cont_input, instruction, context, continuations = prepare_prompt(line)
+        prepared_prompt = prepare_prompt(line)
+        if prepared_prompt is None:
+            return None
+
+        cont_input, instruction, context, continuations = prepared_prompt
 
         options = build_choices(continuations, formulation, translation_literals)
         options = f"{options}\n" if options else ""
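
Dropping the explicit Callable annotation works because create_adapter_from_dict accepts either a callable or a dict mapping and returns a callable in both cases. A hedged sketch of its plausible behavior (an assumption for illustration; the real helper lives elsewhere in the templates utilities):

    def create_adapter_from_dict(adapter):
        # Assumed behavior: pass callables through unchanged; turn a
        # {field: column} mapping into a function reading those columns.
        if callable(adapter):
            return adapter
        return lambda line: {field: line[col] for field, col in adapter.items()}

    # e.g. copa.py below passes {"context": "context", "continuations":
    # "continuations", "gold_idx": "gold_idx"} to the continuation template.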

src/lighteval/tasks/templates/copa.py

Lines changed: 7 additions & 2 deletions
@@ -74,7 +74,9 @@ class COPAAdapter(TypedDict):
 
 
 def get_copa_prompt_function(
-    language: Language, adapter: Callable[[dict], COPAInput] | COPAAdapter, formulation: Formulation = MCFFormulation()
+    language: Language,
+    adapter: Callable[[dict], COPAInput | None] | COPAAdapter,
+    formulation: Formulation = MCFFormulation(),
 ):
     """
     Create a templated prompt function for a COPA task.
@@ -109,7 +111,7 @@ def get_copa_prompt_function(
     Returns:
         Callable: A function that generates COPA prompts based on the given parameters.
     """
-    adapter_fn: Callable[[dict], COPAInput] = create_adapter_from_dict(adapter)  # type: ignore
+    adapter_fn = create_adapter_from_dict(adapter)
     continuation_prompt_fn = get_continuation_prompt_function(
         language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation
     )
@@ -120,6 +122,9 @@ def copa_prompt(
         task_name: str,
     ):
         input_data = adapter_fn(line)
+        if input_data is None:
+            return None
+
         context = capitalize(input_data["context"].rstrip(PUNCT))
         cause_or_effect_trans = (
             translation_literals.cause_word
