Commit 7f6919a

NathanHB (Nathan Habib) and co-author authored
Nathan add logging to metrics (#157)
What this PR does: if you want to log something coming from a metric, simply return it in the metric dict. For example, to log the judge response when using llm_as_judge, simply return the response in the dict:

```
{
    "score": score,
    "judgement": judge_response
}
```

The `judgement` field is a string and will not be aggregated; however, it will be logged in the details for each sample.

---------

Co-authored-by: Nathan Habib <[email protected]>
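As a concrete illustration, a custom sample-level metric along these lines could surface an extra string field next to its score (a minimal sketch; the function name and signature are illustrative, not part of lighteval's API):

```python
# Hypothetical sample-level metric: the numeric "score" is aggregated as usual,
# while the "judgement" string is only logged in the per-sample details.
def exact_match_with_trace(prediction: str, reference: str) -> dict:
    score = float(prediction.strip() == reference.strip())
    return {
        "score": score,  # aggregated across samples
        "judgement": f"pred={prediction!r} vs gold={reference!r}",  # per-sample log only
    }
```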
1 parent ddb5bea commit 7f6919a

File tree: 6 files changed (+30, -30 lines changed)


src/lighteval/evaluator.py

Lines changed: 1 addition & 15 deletions
@@ -117,22 +117,8 @@ def evaluate(  # noqa: C901
         # using a deep copy here because process results pops from the model responses
         metrics = task.process_results(doc, copy.deepcopy(model_responses))

-        # Remove the user_prompt from the metrics in case of llm-as-judge metric
-        if "user_prompt" in metrics:
-            user_prompt = metrics["user_prompt"]
-            del metrics["user_prompt"]
-        else:
-            user_prompt = None
-        if "judgement" in metrics:
-            judgement = metrics["judgement"]
-            del metrics["judgement"]
-        else:
-            judgement = None
-
         evaluation_tracker.metrics_logger.log(task_example_id.task_name, metrics)
-        evaluation_tracker.details_logger.log(
-            task_example_id.task_name, task, doc, model_responses, metrics, (user_prompt, judgement)
-        )
+        evaluation_tracker.details_logger.log(task_example_id.task_name, task, doc, model_responses, metrics)

    return evaluation_tracker

src/lighteval/logging/info_loggers.py

Lines changed: 10 additions & 5 deletions
@@ -205,8 +205,6 @@ class Detail:
     choices: list = field(default_factory=list)
     gold_index: list = field(default_factory=list)
     metrics: dict = field(default_factory=dict)
-    judement_prompt: str = None
-    judgement: str = None
     specifics: dict = field(default_factory=dict)

 @dataclass
@@ -367,11 +365,16 @@ def log(
             detail.choices = doc.choices
             detail.gold_index = as_list(doc.gold_index)
             pred_saved = True
-        if task.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+        if (
+            task.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
+            or task.has_metric_category[MetricCategory.LLM_AS_JUDGE]
+        ):
+            detail.choices = doc.choices
+            detail.gold_index = as_list(doc.gold_index)
             pred_saved = True
-            detail.judement_prompt = llm_as_prompt_judgement[0]
-            detail.judgement = llm_as_prompt_judgement[1]
+
         detail.specifics = doc.specific
+
         if not pred_saved:
             raise NotImplementedError(
                 "No metric prediction saved."
@@ -487,6 +490,8 @@ def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int =
            except OverflowError:
                hlog_warn(f"{task_name}, {metric_name} got an OVERFLOW ERROR when aggregating.")
                metric_result = float("nan")
+            except KeyError:
+                continue

            if isinstance(metric_result, dict):  # For some corpus level grouping metrics
                self.metric_aggregated[task_name].update(metric_result)
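The new `except KeyError: continue` in `aggregate` is what lets per-sample string fields such as `judgement` pass through: no corpus-level aggregation is registered for them, so they are skipped instead of crashing the run. A minimal standalone sketch of that behaviour (not lighteval's actual `MetricsLogger`):

```python
# Fields without an aggregation function are skipped at corpus level; they
# remain available in the per-sample details.
import statistics

aggregations = {"score": statistics.mean}  # "judgement" intentionally has no aggregator

samples = [
    {"score": 1.0, "judgement": "answer matches the reference"},
    {"score": 0.0, "judgement": "answer contradicts the reference"},
]

aggregated = {}
for field in samples[0]:
    try:
        aggregated[field] = aggregations[field]([s[field] for s in samples])
    except KeyError:
        continue  # per-sample-only field: logged in details, not aggregated

print(aggregated)  # {'score': 0.5}
```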

src/lighteval/metrics/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -148,12 +148,15 @@ def apply_multichoice_metric_one_token(results: list[ModelReturn], formatted_doc
     return results, outputs


-def apply_generative_multi_turn_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
+def apply_llm_as_judge_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
     outputs = {}
     predictions = results.pop(0).result

     for metric in metrics:
-        if Metrics[metric].value.category == MetricCategory.GENERATIVE_MULTI_TURN:
+        if (
+            Metrics[metric].value.category == MetricCategory.LLM_AS_JUDGE_MULTI_TURN
+            or Metrics[metric].value.category == MetricCategory.LLM_AS_JUDGE
+        ):
             outputs.update(Metrics[metric].value.compute(predictions=predictions, formatted_doc=formatted_doc))

     return results, outputs

src/lighteval/metrics/metrics.py

Lines changed: 2 additions & 2 deletions
@@ -228,7 +228,7 @@ class Metrics(Enum):
     llm_judge_multi_turn = SampleLevelMetricGrouping(
         metric=["single_turn", "multi_turn"],
         higher_is_better=True,
-        category=MetricCategory.GENERATIVE_MULTI_TURN,
+        category=MetricCategory.LLM_AS_JUDGE_MULTI_TURN,
         use_case=MetricUseCase.SUMMARIZATION,
         sample_level_fn=JudgeLLM(
             judge_model_name="gpt-3.5-turbo",
@@ -243,7 +243,7 @@
     llm_judge = SampleLevelMetricGrouping(
         metric=["judge_score"],
         higher_is_better=True,
-        category=MetricCategory.GENERATIVE,
+        category=MetricCategory.LLM_AS_JUDGE,
         use_case=MetricUseCase.SUMMARIZATION,
         sample_level_fn=JudgeLLM(
             judge_model_name="gpt-3.5-turbo",

src/lighteval/metrics/utils.py

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ class MetricCategory(Enum):
     TARGET_PERPLEXITY = auto()
     PERPLEXITY = auto()
     GENERATIVE = auto()
-    GENERATIVE_MULTI_TURN = auto()
+    LLM_AS_JUDGE_MULTI_TURN = auto()
+    LLM_AS_JUDGE = auto()
     GENERATIVE_LOGPROB = auto()
     MULTICHOICE = auto()
     MULTICHOICE_ONE_TOKEN = auto()

src/lighteval/tasks/lighteval_task.py

Lines changed: 10 additions & 5 deletions
@@ -34,7 +34,7 @@
 from lighteval.metrics import (
     apply_generative_logprob_metric,
     apply_generative_metric,
-    apply_generative_multi_turn_metric,
+    apply_llm_as_judge_metric,
     apply_multichoice_metric,
     apply_multichoice_metric_one_token,
     apply_perplexity_metric,
@@ -412,8 +412,10 @@ def get_request_type(self) -> list[RequestType]:
             request_types.append(RequestType.LOGLIKELIHOOD_ROLLING)
         if self.has_metric_category[MetricCategory.GENERATIVE]:
             request_types.append(RequestType.GREEDY_UNTIL)
-        if self.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+        if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]:
             request_types.append(RequestType.GREEDY_UNTIL_MULTI_TURN)
+        if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]:
+            request_types.append(RequestType.GREEDY_UNTIL)
         if self.has_metric_category[MetricCategory.GENERATIVE_LOGPROB]:
             request_types.append(RequestType.GREEDY_UNTIL_WITH_LOGITS)
         if self.has_metric_category[MetricCategory.MULTICHOICE]:
@@ -504,7 +506,7 @@ def construct_requests(
                     choices=formatted_doc.choices,
                 )
             ]
-        if self.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
+        if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]:
             requests[RequestType.GREEDY_UNTIL_MULTI_TURN] += [
                 GreedyUntilMultiTurnRequest(
                     task_name=current_task_name,
@@ -561,8 +563,11 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dic
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
             )
             outputs.update(cur_outputs)
-        if self.has_metric_category[MetricCategory.GENERATIVE_MULTI_TURN]:
-            results, cur_outputs = apply_generative_multi_turn_metric(
+        if (
+            self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
+            or self.has_metric_category[MetricCategory.LLM_AS_JUDGE]
+        ):
+            results, cur_outputs = apply_llm_as_judge_metric(
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
             )
             outputs.update(cur_outputs)
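For reference, the routing above boils down to a category check on the task: both judge categories, single- and multi-turn, are sent to the same apply function. A toy, self-contained illustration of that pattern (not lighteval's actual `LightevalTask`):

```python
# Tasks expose a {MetricCategory: bool} map; LLM-as-judge metrics of either
# flavour are dispatched to the same judge apply function.
from enum import Enum, auto

class MetricCategory(Enum):
    GENERATIVE = auto()
    LLM_AS_JUDGE = auto()
    LLM_AS_JUDGE_MULTI_TURN = auto()

has_metric_category = {category: False for category in MetricCategory}
has_metric_category[MetricCategory.LLM_AS_JUDGE] = True

if (
    has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
    or has_metric_category[MetricCategory.LLM_AS_JUDGE]
):
    print("route results to apply_llm_as_judge_metric")
```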
