
Commit acef117

chuandudx, clefourrier, and NathanHB authored
Fix BLEURT evaluation errors (#316)
These changes address the issues described in #315. I made the code changes so that they build on the BERTScore changes (#311), which haven't been merged yet, so those changes appear here as well. Please let me know if there is a preference for removing them from this PR. Thank you!

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
1 parent b4cafa7 · commit acef117

File tree

2 files changed: 5 additions & 5 deletions


src/lighteval/metrics/metrics.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -121,12 +121,13 @@ class Metrics(Enum):
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
+
     bleurt = SampleLevelMetric(
         metric_name="bleurt",
-        sample_level_fn=BLEURT.compute,
+        sample_level_fn=BLEURT().compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.TRANSLATION,
-        corpus_level_fn=lambda x: np.mean(x.flatten()),  # flatten, then average
+        corpus_level_fn=np.mean,
         higher_is_better=True,
     )
     byte_perplexity = CorpusLevelMetric(
```
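
Why the first change is needed: `compute` is an instance method that reads `self.tokenizer` and `self.model`, both created in `BLEURT.__init__`, so the enum must hold a bound method from a constructed instance. Referencing `BLEURT.compute` on the class gives a plain function whose first positional argument gets consumed as `self`. A minimal sketch of that failure mode, using a hypothetical `Scorer` class rather than the real lighteval one:

```python
# Hypothetical Scorer class, sketching why BLEURT.compute had to
# become BLEURT().compute.
class Scorer:
    def __init__(self):
        self.offset = 1.0  # stands in for the model/tokenizer built in __init__

    def compute(self, value: float) -> float:
        return value + self.offset  # relies on instance state


unbound = Scorer.compute  # plain function: unbound(0.5) raises TypeError
                          # (0.5 is consumed as `self`, `value` is missing)
bound = Scorer().compute  # bound method: __init__ has run, state exists
print(bound(0.5))         # 1.5
```

The `corpus_level_fn` simplification pairs with the `.item()` change in the next file: once each sample-level score is a plain Python float, `np.mean` over the list of scores is enough, and the old lambda's `x.flatten()` (which assumed an ndarray input) no longer applies.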

src/lighteval/metrics/metrics_sample.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -702,7 +702,7 @@ def __init__(self):
         self.model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512")
         self.model.eval()
 
-    def compute(self, golds: list[str], predictions: list[str]) -> float:
+    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
         """Uses the stored BLEURT scorer to compute the score on the current sample.
 
         Args:
@@ -715,8 +715,7 @@ def compute(self, golds: list[str], predictions: list[str]) -> float:
         if len(predictions) == 1:
             predictions = predictions * len(golds)
         scores = self.model(**self.tokenizer(golds, predictions, return_tensors="pt"))[0].squeeze()
-
-        return scores
+        return scores.item()
 
 
 class BLEU:
```