
Commit df22af8

Add mt-bench (#75)
What this PR does:
- Uses custom metrics and tasks to add LLM-as-a-judge evaluation
- Adds multi-turn generation
- Adds the mt-bench metric

This implementation uses the mt-bench prompts from [InflectionAI](https://github.com/InflectionAI/Inflection-Benchmarks). The code is inspired by the original mt-bench implementation, with notable differences:
- mt-bench uses a custom-made chat templating system; we use the tokenizer's chat template.
- mt-bench uses an old version of the OpenAI API; we use the newest one, with much simpler logic for chat prompt formatting. We can easily add more models to act as judge.
- We do not vary the temperature based on the sample being evaluated. All samples are generated with `do_sample=False` and a temperature of `0.0`.
1 parent a91cff3 commit df22af8
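
As noted in the description, judging goes through the current OpenAI client rather than the legacy API. Below is a minimal sketch of what such a judge call can look like with the new client; the judge model name, system prompt, and rating convention are illustrative assumptions, not the exact code added by this PR.

```python
# Illustrative sketch only: querying a judge model with the current OpenAI client.
# The model name, system prompt, and rating format are assumptions for this example.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def judge_answer(question: str, answer: str, judge_model: str = "gpt-4") -> str:
    """Ask the judge model to rate a single answer and return its raw verdict."""
    response = client.chat.completions.create(
        model=judge_model,
        temperature=0.0,  # deterministic judging, matching the deterministic generation
        messages=[
            {
                "role": "system",
                "content": "You are an impartial judge. Rate the answer from 1 to 10 as [[rating]].",
            },
            {"role": "user", "content": f"[Question]\n{question}\n\n[Answer]\n{answer}"},
        ],
    )
    return response.choices[0].message.content
```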

File tree

16 files changed: +716 -29 lines

pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -88,7 +88,8 @@ quality = ["ruff==v0.2.2","pre-commit"]
 tests = ["pytest==7.4.0"]
 dev = ["lighteval[accelerate,quality,tests]"]
 extended_tasks = [
-    "langdetect", #ifeval
+    "langdetect", # ifeval
+    "openai", # mt-bench
 ]

 [project.urls]

src/lighteval/evaluator.py

Lines changed: 17 additions & 1 deletion

@@ -88,6 +88,8 @@ def evaluate( # noqa: C901
             full_resps = lm.greedy_until_with_logits(requests, override_bs=override_bs)
         elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
             full_resps = lm.loglikelihood_rolling(requests, override_bs=override_bs)
+        elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
+            full_resps = lm.greedy_until_multi_turn(requests, override_bs=override_bs)
         else:
             raise NotImplementedError(f"Request type {request_type} not supported")

@@ -115,8 +117,22 @@ def evaluate( # noqa: C901
         # using a deep copy here because process results pops from the model responses
         metrics = task.process_results(doc, copy.deepcopy(model_responses))

+        # Remove the user_prompt from the metrics in case of llm-as-judge metric
+        if "user_prompt" in metrics:
+            user_prompt = metrics["user_prompt"]
+            del metrics["user_prompt"]
+        else:
+            user_prompt = None
+        if "judgement" in metrics:
+            judgement = metrics["judgement"]
+            del metrics["judgement"]
+        else:
+            judgement = None
+
         evaluation_tracker.metrics_logger.log(task_example_id.task_name, metrics)
-        evaluation_tracker.details_logger.log(task_example_id.task_name, task, doc, model_responses, metrics)
+        evaluation_tracker.details_logger.log(
+            task_example_id.task_name, task, doc, model_responses, metrics, (user_prompt, judgement)
+        )

     return evaluation_tracker
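
For context, the popping above implies that an llm-as-judge metric returns its judge prompt and verdict alongside the numeric scores. Here is a toy sketch of that shape and of the pop logic; the score key and values are illustrative, only `user_prompt` and `judgement` come from the diff above.

```python
# Toy sketch: the kind of dict an llm-as-judge metric could return, and why
# evaluate() pops "user_prompt" and "judgement" before logging numeric scores.
def mock_judge_metric(prediction: str) -> dict:
    judge_prompt = f"Rate this answer from 1 to 10: {prediction}"
    judge_verdict = "[[8]]"  # what a judge model might answer
    return {
        "judge_score": 8.0,           # numeric score, stays in `metrics`
        "user_prompt": judge_prompt,  # popped and stored in the details
        "judgement": judge_verdict,   # popped and stored in the details
    }


metrics = mock_judge_metric("The capital of France is Paris.")
user_prompt = metrics.pop("user_prompt", None)
judgement = metrics.pop("judgement", None)
print(metrics)  # {'judge_score': 8.0} -- only scores reach the metrics logger
```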

src/lighteval/few_shot_manager.py

Lines changed: 41 additions & 1 deletion

@@ -27,7 +27,7 @@
 from itertools import cycle
 from typing import TYPE_CHECKING, Optional

-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer

 from lighteval.logging.hierarchical_logger import hlog_warn
 from lighteval.tasks.requests import Doc

@@ -219,6 +219,46 @@ def get_examples(
         )
         return instruction + labeled_examples + example

+    def create_multi_turn_contexts(
+        self, doc: Doc, use_chat_template: bool, system_prompt: Optional[str], tokenizer: PreTrainedTokenizer
+    ) -> list[str]:
+        """Creates one context per turn of the task.
+        Multi-turn tasks need to use chat templating.
+
+        Args:
+            doc (Doc): Formatted document.
+            use_chat_template (bool): Whether or not to use the chat template. Will fail if False.
+            system_prompt (Optional[str]): The system prompt to use.
+            tokenizer (PreTrainedTokenizer): The tokenizer used for the chat template.
+
+        Raises:
+            ValueError: If use_chat_template is set to False.
+
+        Returns:
+            list[str]: Contexts for every turn.
+        """
+        if not use_chat_template:
+            raise ValueError("You need to use the chat template to create multi turn contexts")
+
+        role_content_list = []
+        if system_prompt is not None:
+            role_content_list.append({"role": "system", "content": system_prompt})
+
+        for i in doc.specific["multi_turn_queries"]:
+            role_content_list.append({"role": "user", "content": i})
+            role_content_list.append({"role": "assistant", "content": "{model_response}"})
+        role_content_list.pop(-1)
+
+        contexts = []
+        offset = 2 if system_prompt is not None else 1
+        for i in range(0, len(role_content_list), offset + 1):
+            c = tokenizer.apply_chat_template(
+                role_content_list[: i + offset], add_generation_prompt=True, tokenize=False, add_special_tokens=False
+            )
+            contexts.append(c)
+
+        return contexts, 0
+
     def fewshot_context(
         self,
         task: "LightevalTask",

src/lighteval/logging/info_loggers.py

Lines changed: 21 additions & 3 deletions

@@ -24,7 +24,7 @@
 import os
 import time
 from dataclasses import asdict, dataclass, field
-from typing import Union
+from typing import Optional, Union

 import git
 import numpy as np

@@ -205,6 +205,9 @@ class Detail:
         choices: list = field(default_factory=list)
         gold_index: list = field(default_factory=list)
         metrics: dict = field(default_factory=dict)
+        judement_prompt: str = None
+        judgement: str = None
+        specifics: dict = field(default_factory=dict)

     @dataclass
     class CompiledDetail:

@@ -302,7 +305,15 @@ class CompiledHash:
     compiled_details: dict[str, CompiledDetail] = collections.defaultdict(CompiledDetail)
     compiled_details_over_all_tasks: CompiledDetailOverAllTasks = CompiledDetailOverAllTasks()

-    def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[ModelReturn], metrics: dict) -> None:
+    def log(
+        self,
+        task_name: str,
+        task: LightevalTask,
+        doc: Doc,
+        outputs: list[ModelReturn],
+        metrics: dict,
+        llm_as_prompt_judgement: Optional[tuple[str, str]] = None,
+    ) -> None:
         """Stores the relevant information for one sample of one task to the total list of samples stored in the DetailsLogger.

@@ -311,6 +322,8 @@ def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[Model
             doc (Doc): Current sample that we want to store.
             outputs (list[ModelReturn]): Model outputs for the current sample
             metrics (_type_): Model scores for said sample on the current task's metrics.
+            llm_as_prompt_judgement (tuple[str, str]): Tuple containing the
+                prompt passed to the judge and the judgement for the current sample when using llm-as-judge metric.
         """
         detail = self.Detail()
         detail.example = doc.query

@@ -354,6 +367,11 @@ def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[Model
             detail.choices = doc.choices
            detail.gold_index = as_list(doc.gold_index)
             pred_saved = True
+        if task.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+            pred_saved = True
+            detail.judement_prompt = llm_as_prompt_judgement[0]
+            detail.judgement = llm_as_prompt_judgement[1]
+            detail.specifics = doc.specific
         if not pred_saved:
             raise NotImplementedError(
                 "No metric prediction saved."

@@ -364,7 +382,7 @@ def log(self, task_name: str, task: LightevalTask, doc: Doc, outputs: list[Model

         hash = self.Hash()
         hash.example = xxhash.xxh64(doc.query).hexdigest()
-        hash.full_prompt = xxhash.xxh64(doc.ctx).hexdigest()
+        hash.full_prompt = xxhash.xxh64(str(doc.ctx)).hexdigest()
         hash.input_tokens = xxhash.xxh64(str([o.input_tokens for o in outputs])).hexdigest()
         hash.cont_tokens = xxhash.xxh64(str([o.generated_tokens for o in outputs])).hexdigest()
         self.hashes[task_name].append(hash)

src/lighteval/metrics/__init__.py

Lines changed: 11 additions & 0 deletions

@@ -146,3 +146,14 @@ def apply_multichoice_metric_one_token(results: list[ModelReturn], formatted_doc
     )

     return results, outputs
+
+
+def apply_generative_multi_turn_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
+    outputs = {}
+    predictions = results.pop(0).result
+
+    for metric in metrics:
+        if Metrics[metric].value.category == MetricCategory.GENERATIVE_MULTI_TURN:
+            outputs.update(Metrics[metric].value.compute(predictions=predictions, formatted_doc=formatted_doc))
+
+    return results, outputs

src/lighteval/metrics/utils.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ class MetricCategory(Enum):
     TARGET_PERPLEXITY = auto()
     PERPLEXITY = auto()
     GENERATIVE = auto()
+    GENERATIVE_MULTI_TURN = auto()
     GENERATIVE_LOGPROB = auto()
     MULTICHOICE = auto()
     MULTICHOICE_ONE_TOKEN = auto()

src/lighteval/models/abstract_model.py

Lines changed: 13 additions & 1 deletion

@@ -27,8 +27,14 @@
 from transformers import BatchEncoding

 from lighteval.models.model_config import EnvConfig
-from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
+from lighteval.models.model_output import (
+    GenerateMultiTurnReturn,
+    GenerateReturn,
+    LoglikelihoodReturn,
+    LoglikelihoodSingleTokenReturn,
+)
 from lighteval.tasks.requests import (
+    GreedyUntilMultiTurnRequest,
     GreedyUntilRequest,
     GreedyUntilWithLogitsRequest,
     LoglikelihoodRequest,

@@ -102,6 +108,12 @@ def greedy_until_with_logits(
             returns_logits=True,
         )

+    def greedy_until_multi_turn(  # noqa: C901
+        self, requests: list[GreedyUntilMultiTurnRequest], override_bs: Optional[int] = None
+    ) -> GenerateMultiTurnReturn:
+        """Generates responses using a greedy decoding strategy until certain ending conditions are met."""
+        return NotImplemented
+
     @abstractmethod
     def greedy_until(
         self,
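
A concrete backend has to override this hook. Below is a minimal sketch of the turn-by-turn loop such an implementation could use, splicing each generated answer into the next turn's context through the `{model_response}` placeholder produced by `create_multi_turn_contexts`; the helper name and the generation callable are assumptions, not the PR's actual model code.

```python
# Minimal sketch (assumption, not the PR's actual implementation): generate turn
# by turn, filling earlier answers into later contexts before generating again.
from typing import Callable


def greedy_until_multi_turn_sketch(
    generate_fn: Callable[[str], str],  # stand-in for greedy generation (do_sample=False, temperature 0.0)
    contexts: list[str],                # per-turn contexts from create_multi_turn_contexts
) -> list[str]:
    responses: list[str] = []
    for context in contexts:
        prompt = context
        # Replace placeholders with previously generated answers, in order.
        for previous in responses:
            prompt = prompt.replace("{model_response}", previous, 1)
        responses.append(generate_fn(prompt))
    return responses


# Toy usage with an echoing "model":
print(greedy_until_multi_turn_sketch(lambda p: f"<answer to: {p}>", ["Q1?", "Q1? {model_response} Q2?"]))
```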
