
Commit cb32024

Authored by NathanHB, anilaltuner, and clefourrier
adds llm as judge using transformers (#223)
- lazy load of the openai and transformers libraries
- able to use the OpenAI client to prompt transformers models
- add Llama 3.1 405B as an LLM-as-a-judge in the default metrics
- update the docs
- make sure we are not rate limited
- add an option to use a local transformers model as judge

---------

Co-authored-by: anilaltuner <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 38c03ad commit cb32024

9 files changed: +152 additions, −91 deletions


README.md

Lines changed: 5 additions & 0 deletions

@@ -415,6 +415,11 @@ These metrics need the model to generate an output. They are therefore slower.
 - `maj_at_4_math` (Lighteval): Majority choice evaluation, using the math normalisation for the predictions and gold
 - `quasi_exact_match_gsm8k` (Harness): Fraction of instances where the normalized prediction matches the normalized gold (normalization done for gsm8k, where latex symbols, units, etc. are removed)
 - `maj_at_8_gsm8k` (Lighteval): Majority choice evaluation, using the gsm8k normalisation for the predictions and gold
+- LLM-as-Judge:
+  - `llm_judge_gpt3p5`: Can be used for any generative task; the model is scored by a GPT-3.5 judge through the OpenAI API
+  - `llm_judge_llama_3_405b`: Can be used for any generative task; the model is scored by a Llama 3.1 405B judge through an OpenAI-compatible client (the HuggingFace Inference API by default)
+  - `llm_judge_multi_turn_gpt3p5`: Can be used for any generative task; the model is scored by a GPT-3.5 judge through the OpenAI API. Used for multi-turn tasks such as mt-bench.
+  - `llm_judge_multi_turn_llama_3_405b`: Can be used for any generative task; the model is scored by a Llama 3.1 405B judge through an OpenAI-compatible client. Used for multi-turn tasks such as mt-bench.

 ### Metrics for specific tasks
 To keep compatibility with the Harness for some specific tasks, we ported their evaluations more or less as such. They include `drop` (for the DROP dataset) and `truthfulqa_mc_metrics` (for TruthfulQA). In general, except for tasks where the dataset has very different formatting than usual (another language, programming language, math, ...), we want to use standard implementations of the above metrics. It makes little sense to have 10 different versions of an exact match depending on the task. However, most of the above metrics are parametrizable so that you can change the normalization applied easily for experimental purposes.
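A practical note on the judge metrics above: the two judge families resolve their credentials from different environment variables (this follows the JudgeLLM constructor change further down in this commit). A minimal sketch, with placeholder key values — export real keys in your shell instead:

# Sketch only: the key values are placeholders.
import os

# GPT-3.5 judges (llm_judge_gpt3p5, llm_judge_multi_turn_gpt3p5) call the OpenAI API.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")

# Llama 3.1 405B judges (llm_judge_llama_3_405b, llm_judge_multi_turn_llama_3_405b)
# call the HuggingFace Inference API through an OpenAI-compatible client.
os.environ.setdefault("HF_TOKEN", "hf_...")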

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ tests = ["pytest==7.4.0"]
 dev = ["lighteval[accelerate,quality,tests]"]
 extended_tasks = [
   "langdetect", # ifeval
-  "openai", # mt-bench
+  "openai", # llm as a judge using openai models
 ]

 [project.urls]

src/lighteval/logging/evaluation_tracker.py

Lines changed: 4 additions & 1 deletion

@@ -57,7 +57,10 @@ class EnhancedJSONEncoder(json.JSONEncoder):

     def default(self, o):
         if is_dataclass(o):
-            return asdict(o)
+            try:
+                return asdict(o)
+            except Exception:
+                return str(o)
         if callable(o):
             return o.__name__
         if isinstance(o, Enum):
src/lighteval/metrics/llm_as_judge.py

Lines changed: 83 additions & 61 deletions

@@ -25,56 +25,57 @@
 import json
 import re
 import time
-from typing import Optional
+from typing import Any, Optional
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

 from lighteval.logging.hierarchical_logger import hlog_warn
-from lighteval.utils import NO_OPENAI_ERROR_MSG, is_openai_available


-class JudgeOpenAI:
+class JudgeLM:
     """
-    A class representing a judge for evaluating answers using the OpenAI API.
+    A class representing a judge for evaluating answers using either the OpenAI API or the Transformers library.

     Args:
-        model (str): The name of the OpenAI model to use.
-        seed (int): The seed value for generating random responses.
-        temperature (float): The temperature value for controlling the randomness of the responses.
+        model (str): The name of the model to use.
         templates_path (str): The path to the JSON file containing the templates for prompts.
+        multi_turn (bool): Whether to use multi-turn prompts.
+        url (Optional[str]): The URL for the OpenAI API.
+        api_key (Optional[str]): The API key for the OpenAI API (either an OpenAI or a HF key).

     Attributes:
-        client: An instance of the OpenAI client.
-        model (str): The name of the OpenAI model.
-        seed (int): The seed value, passed to the API when generating responses.
-        temperature (float): The temperature value, passed to the API when generating responses.
+        model (str): The name of the model.
         templates (dict): A dictionary containing the templates for prompts.
         one_score_pattern (re.Pattern): A regular expression pattern for extracting scores from the response.
         one_score_pattern_backup (re.Pattern): A backup regular expression pattern for extracting scores.
-        API_MAX_RETRY (int): The maximum number of API retries.
-        API_RETRY_SLEEP (int): The sleep time between API retries.
-        max_tokens (int): The maximum number of tokens allowed in the response.
+        API_MAX_RETRY (int): The maximum number of retries for the API.
+        API_RETRY_SLEEP (int): The sleep time between retries.
+        client (Optional[OpenAI]): The OpenAI client.
+        pipe (Optional[pipeline]): The Transformers pipeline.
+        use_transformers (bool): Whether to use the Transformers library.
+        url (Optional[str]): The URL for the OpenAI API.
+        api_key (Optional[str]): The API key for the OpenAI API (either an OpenAI or a HF key).

     Methods:
-        evaluate_answer: Evaluates an answer using the OpenAI API.
+        evaluate_answer: Evaluates an answer using the OpenAI API or the Transformers library.
         __get_prompts_multi_turn: Generates prompts for multi-turn conversations.
         __get_prompts_single_turn: Generates prompts for single-turn conversations.
         __process_judge_response: Processes the judge's response and extracts the score.
+        __call_api: Calls the OpenAI-compatible API to get the judge's response.
+        __lazy_load_client: Lazy loads the OpenAI client or the Transformers pipeline.
     """

     def __init__(
         self,
         model: str,
-        seed: int,
-        temperature: float,
         templates_path: str,
-        openai_api_key: str,
         multi_turn: bool = False,
+        url: Optional[str] = None,
+        api_key: Optional[str] = None,
     ):
-        self.client = None  # loaded lazily
-        self.openai_api_key = openai_api_key
-        self.model = model
-        self.seed = seed
-        self.temperature = temperature
         self.multi_turn = multi_turn
+        self.model = model

         data = []
         with open(templates_path, "r") as f:
@@ -89,40 +90,59 @@ def __init__(
         # the second is for the backup case: [score]
         self.one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
         self.one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")
+        self.API_MAX_RETRY = 3
+        self.API_RETRY_SLEEP = 1
+
+        self.client = None
+        self.pipe = None
+
+        self.use_transformers = url is None and api_key is None
+
+        self.url = url
+        self.api_key = api_key
+
+    def __lazy_load_client(self):
+        if self.use_transformers:
+            if self.pipe is None:
+                transformers_model = AutoModelForCausalLM.from_pretrained(
+                    self.model, torch_dtype=torch.bfloat16, trust_remote_code=False, device_map="cuda"
+                )
+                tokenizer = AutoTokenizer.from_pretrained(self.model)
+                self.pipe = pipeline(
+                    "text-generation",
+                    model=transformers_model,
+                    tokenizer=tokenizer,
+                    max_new_tokens=50,
+                )
+        else:
+            if self.client is None:
+                from openai import OpenAI

-        self.API_MAX_RETRY = 16
-        self.API_RETRY_SLEEP = 10
-        self.max_tokens = 2048
+                if self.url is None:
+                    self.client = OpenAI(api_key=self.api_key)
+                else:
+                    self.client = OpenAI(base_url=self.url, api_key=self.api_key)

     def evaluate_answer(
         self, questions: list[str], answers: list[str], references: list[str]
-    ) -> tuple[int, list[dict[str, str]], str]:
+    ) -> tuple[list[int], list[list[dict[str, str]]], list[str | None | Any]]:
         """
-        Evaluates an answer using the OpenAI API.
+        Evaluates an answer using either the Transformers pipeline or the OpenAI API.

         Args:
             questions (list[str]): A list of questions (can be a list because of multi-turn conversations)
             answers (list[str]): A list of answers, one for each question.
             references (list[str]): A list of reference answers, one for each question (sometimes not available)
-            single_turn (bool): Indicates whether the conversation is single-turn or multi-turn.

         Returns:
             A tuple containing the scores, prompts, and judgments.
-
-        Raises:
-            Exception: If an error occurs during the API call.
         """
-        if self.client is None:
-            if not is_openai_available():
-                raise ImportError(NO_OPENAI_ERROR_MSG)
-
-            from openai import OpenAI
-
-            self.client = OpenAI(api_key=self.openai_api_key)
+        # lazy loading of the client or pipeline
+        self.__lazy_load_client()

         prompts = [
             self.__get_prompts_single_turn(
-                questions[0], answers[0], references[0] if references is not None and len(references) > 0 else None
+                questions[0], answers[0], references[0] if references and len(references) > 0 else None
             )
         ]

@@ -132,28 +152,15 @@ def evaluate_answer(
             )
             prompts.append(prompts_multi_turn)

-        responses = []
+        judgments = []
         for prompt in prompts:
-            for _ in range(self.API_MAX_RETRY):
-                try:
-                    response = self.client.chat.completions.create(
-                        model=self.model,
-                        seed=self.seed,
-                        temperature=self.temperature,
-                        messages=prompt,
-                        max_tokens=self.max_tokens,
-                        n=1,
-                    )
-                    responses.append(response)
-                    break
-                except Exception as e:
-                    hlog_warn(f"{type(e), e}")
-                    time.sleep(self.API_RETRY_SLEEP)
-
-        if len(responses) == 0:
-            raise Exception("Failed to get response from the API")
-
-        judgments = [response.choices[0].message.content for response in responses]
+            if self.client is not None:
+                response = self.__call_api(prompt)
+            else:
+                response = self.pipe(prompt)[0]["generated_text"]
+                response = response[-1]["content"]
+            judgments.append(response)
+
         scores = [self.__process_judge_response(judgment) for judgment in judgments]

         return scores, prompts, judgments
@@ -235,3 +242,18 @@ def __process_judge_response(self, judgment: str) -> int:
             rating = -1

         return rating
+
+    def __call_api(self, prompt):
+        for _ in range(self.API_MAX_RETRY):
+            try:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=prompt,
+                    max_tokens=512,
+                    n=1,
+                )
+                return response.choices[0].message.content
+            except Exception as e:
+                hlog_warn(f"{type(e), e}")
+                time.sleep(self.API_RETRY_SLEEP)
+        raise Exception("Failed to get response from the API")
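A short usage sketch of the new class under its two backends. The model names and the template path are illustrative; the remote call needs a valid token and the local path needs a CUDA-capable GPU:

from lighteval.metrics.llm_as_judge import JudgeLM

# Remote judge: any OpenAI-compatible endpoint, here the HF Inference API.
remote_judge = JudgeLM(
    model="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
    templates_path="judge_prompts.jsonl",
    multi_turn=False,
    url="https://api-inference.huggingface.co/v1/",
    api_key="hf_...",  # an OpenAI key with url=None targets api.openai.com instead
)

# Local judge: url and api_key left as None, so a transformers pipeline is lazy-loaded on "cuda".
local_judge = JudgeLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    templates_path="judge_prompts.jsonl",
    multi_turn=False,
)

scores, prompts, judgments = remote_judge.evaluate_answer(
    questions=["What is 2 + 2?"], answers=["4"], references=["4"]
)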

src/lighteval/metrics/metrics.py

Lines changed: 33 additions & 4 deletions

@@ -228,9 +228,9 @@ class Metrics(Enum):
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
-    llm_judge_multi_turn_openai = SampleLevelMetricGrouping(
+    llm_judge_multi_turn_gpt3p5 = SampleLevelMetricGrouping(
         metric_name=["single_turn", "multi_turn"],
-        higher_is_better=True,
+        higher_is_better={"single_turn": True, "multi_turn": True},
         category=MetricCategory.LLM_AS_JUDGE_MULTI_TURN,
         use_case=MetricUseCase.SUMMARIZATION,
         sample_level_fn=JudgeLLM(
@@ -243,9 +243,24 @@ class Metrics(Enum):
             "multi_turn": np.mean,
         },
     )
-    llm_judge_openai = SampleLevelMetricGrouping(
+    llm_judge_multi_turn_llama_3_405b = SampleLevelMetricGrouping(
+        metric_name=["single_turn", "multi_turn"],
+        higher_is_better={"single_turn": True, "multi_turn": True},
+        category=MetricCategory.LLM_AS_JUDGE_MULTI_TURN,
+        use_case=MetricUseCase.SUMMARIZATION,
+        sample_level_fn=JudgeLLM(
+            judge_model_name="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
+            template_path=os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl"),
+            multi_turn=True,
+        ).compute,
+        corpus_level_fn={
+            "single_turn": np.mean,
+            "multi_turn": np.mean,
+        },
+    )
+    llm_judge_gpt3p5 = SampleLevelMetricGrouping(
         metric_name=["judge_score"],
-        higher_is_better=True,
+        higher_is_better={"judge_score": True},
         category=MetricCategory.LLM_AS_JUDGE,
         use_case=MetricUseCase.SUMMARIZATION,
         sample_level_fn=JudgeLLM(
@@ -257,6 +272,20 @@ class Metrics(Enum):
             "judge_score": np.mean,
         },
     )
+    llm_judge_llama_3_405b = SampleLevelMetricGrouping(
+        metric_name=["judge_score"],
+        higher_is_better={"judge_score": True},
+        category=MetricCategory.LLM_AS_JUDGE,
+        use_case=MetricUseCase.SUMMARIZATION,
+        sample_level_fn=JudgeLLM(
+            judge_model_name="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
+            template_path=os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl"),
+            multi_turn=False,
+        ).compute,
+        corpus_level_fn={
+            "judge_score": np.mean,
+        },
+    )
     loglikelihood_acc = SampleLevelMetric(
         metric_name="acc",
         sample_level_fn=LoglikelihoodAcc().compute,
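Following the same pattern, a project-specific judge metric can be defined outside this commit. A hedged sketch only: the import path for SampleLevelMetricGrouping, the judge model name, and the template location are assumptions, not part of this change:

import os

import numpy as np

from lighteval.metrics.metrics_sample import JudgeLLM
from lighteval.metrics.utils import MetricCategory, MetricUseCase, SampleLevelMetricGrouping  # assumed import path

llm_judge_local_llama_8b = SampleLevelMetricGrouping(
    metric_name=["judge_score"],
    higher_is_better={"judge_score": True},
    category=MetricCategory.LLM_AS_JUDGE,
    use_case=MetricUseCase.SUMMARIZATION,
    sample_level_fn=JudgeLLM(
        judge_model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder judge model
        template_path=os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl"),  # your own template file
        multi_turn=False,
        use_transformers=True,  # run the judge locally instead of calling an API
    ).compute,
    corpus_level_fn={"judge_score": np.mean},
)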

src/lighteval/metrics/metrics_sample.py

Lines changed: 22 additions & 11 deletions

@@ -29,6 +29,7 @@

 import nltk
 import numpy as np
+from huggingface_hub import HfApi
 from nltk.metrics.distance import edit_distance
 from nltk.tokenize import word_tokenize
 from nltk.tokenize.treebank import TreebankWordTokenizer
@@ -40,7 +41,7 @@
 from lighteval.metrics.imports.bert_scorer import BERTScorer
 from lighteval.metrics.imports.data_stats_metric import DataStatsMetric
 from lighteval.metrics.imports.summac import SummaCZS
-from lighteval.metrics.llm_as_judge import JudgeOpenAI
+from lighteval.metrics.llm_as_judge import JudgeLM
 from lighteval.metrics.normalizations import remove_braces, remove_braces_and_strip
 from lighteval.tasks.requests import Doc
 from lighteval.utils import as_list
@@ -626,22 +627,32 @@ def edit_similarity(self, s1, s2):


 class JudgeLLM:
-    available_models = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"]
+    available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"]

-    def __init__(self, judge_model_name: str, template_path: str, multi_turn: bool = False):
-        if judge_model_name not in self.available_models:
-            raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")
+    def __init__(
+        self, judge_model_name: str, template_path: str, multi_turn: bool = False, use_transformers: bool = False
+    ) -> None:
+        if judge_model_name in self.available_models_openai:
+            api_key = os.getenv("OPENAI_API_KEY")
+            url = None
+        elif not use_transformers:
+            api_key = os.getenv("HF_TOKEN")
+            url = "https://api-inference.huggingface.co/v1/"
+        else:
+            api = HfApi()
+            models = api.list_models(model_name=judge_model_name)
+            url = None
+            api_key = None
+            if not models:
+                raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")

-        OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
         self.multi_turn = multi_turn
-
-        self.judge = JudgeOpenAI(
+        self.judge = JudgeLM(
             model=judge_model_name,
-            seed=42,
-            temperature=0.0,
             templates_path=template_path,
-            openai_api_key=OPENAI_API_KEY,
             multi_turn=multi_turn,
+            api_key=api_key,
+            url=url,
         )

     def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
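The updated constructor now routes to one of three backends. A sketch of the three cases; the template paths and the 8B model name are placeholders:

from lighteval.metrics.metrics_sample import JudgeLLM

# 1) Closed judge: the name is in available_models_openai, key read from OPENAI_API_KEY.
openai_judge = JudgeLLM("gpt-4o", template_path="judge_prompts.jsonl")

# 2) Hosted open-weights judge: any other name with use_transformers=False is routed
#    to https://api-inference.huggingface.co/v1/ with HF_TOKEN as the key.
hosted_judge = JudgeLLM(
    "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8", template_path="judge_prompts.jsonl"
)

# 3) Local judge: use_transformers=True first checks the name exists on the Hub,
#    then JudgeLM lazy-loads it through a transformers pipeline on first use.
local_judge = JudgeLLM(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    template_path="judge_prompts.jsonl",
    use_transformers=True,
)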

src/lighteval/tasks/extended/mt_bench/main.py

Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@
 # ruff: noqa: F405, F403, F401, I001
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
+from lighteval.metrics.metrics import Metrics


 def mt_bench_prompt(line, task_name: str = None):
@@ -55,7 +56,7 @@ def mt_bench_prompt(line, task_name: str = None):
     evaluation_splits=["train"],
     few_shots_split="",
     few_shots_select="random",
-    metric=["llm_judge_multi_turn_openai"],
+    metric=[Metrics.llm_judge_multi_turn_gpt3p5],
     generation_size=1024,
     stop_sequence=[],
 )
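Since the task now points at the enum member rather than a bare string, other task configs can reference the judge metrics the same way. A small sketch (assumes the enum members wrap the SampleLevelMetricGrouping objects defined above):

from lighteval.metrics.metrics import Metrics

for metric in (Metrics.llm_judge_multi_turn_gpt3p5, Metrics.llm_judge_multi_turn_llama_3_405b):
    # each member wraps a grouping that reports both turn scores
    print(metric.name, metric.value.metric_name)  # -> ['single_turn', 'multi_turn']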
