Skip to content

Commit 80770da

Browse files
rolshoven and NathanHB
authored
Fixed bug that prevented the metrics from being mixed (some batched and others not batched) (#958)
Co-authored-by: Nathan Habib <[email protected]>
1 parent 2b0fffe commit 80770da

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

src/lighteval/metrics/__init__.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,36 @@
2727

2828

2929
def apply_metric(responses: list[ModelResponse], docs: list[Doc], metrics: list[Metric]):
    """Compute every metric for every sample, merging results into one dict per sample.

    Supports mixing batched metrics (``batched_compute`` is True; called once with
    all responses and docs) and non-batched metrics (called once per sample).

    Args:
        responses: One model response per document; must align with ``docs``.
        docs: The evaluated documents; defines the number of output samples.
        metrics: Metrics to apply; each exposes ``batched_compute`` and
            ``compute_sample``.

    Returns:
        A list with one dict per document containing the merged outputs of all
        metrics for that sample. If two metrics emit the same key, the later
        metric's value wins (dict.update semantics).
    """
    # Split once so each group is invoked with the calling convention it expects.
    batched_metrics = [m for m in metrics if m.batched_compute]
    non_batched_metrics = [m for m in metrics if not m.batched_compute]

    # Each batched metric returns one output dict per document, aligned by index.
    # (No emptiness guard needed: iterating an empty list is a no-op.)
    batched_outputs = [
        metric.compute_sample(responses=responses, docs=docs) for metric in batched_metrics
    ]

    outputs = []
    for i, doc in enumerate(docs):
        output = {}

        # Merge this sample's slice of every batched metric's results.
        for metric_outputs in batched_outputs:
            output.update(metric_outputs[i])

        # Compute and merge non-batched metrics one sample at a time.
        # responses[i] (not zip) preserves an IndexError if responses is shorter.
        for metric in non_batched_metrics:
            output.update(
                metric.compute_sample(
                    model_response=responses[i],
                    doc=doc,
                )
            )

        outputs.append(output)

    return outputs

0 commit comments

Comments
 (0)