
Commit ab05ad3

Class implementations of faithfulness and extractiveness metrics (#323)
Refactored extractiveness into a class and updated its instantiation in metrics.py. Refactored faithfulness into a class and updated its instantiation in metrics.py. Enabled a configurable SummaCZS model and a configurable input_column.

Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 0ada14b commit ab05ad3

File tree: 2 files changed (+106, −22 lines)

src/lighteval/metrics/metrics.py
src/lighteval/metrics/metrics_sample.py

src/lighteval/metrics/metrics.py

Lines changed: 8 additions & 4 deletions
@@ -40,15 +40,15 @@
     ROUGE,
     BertScore,
     ExactMatches,
+    Extractiveness,
     F1_score,
+    Faithfulness,
     JudgeLLM,
     LoglikelihoodAcc,
     MajAtK,
     Recall,
     StringDistance,
     acc_golds_likelihood,
-    extractiveness,
-    faithfulness,
 )
 from lighteval.metrics.normalizations import (
     LogProbCharNorm,
@@ -175,7 +175,9 @@ class Metrics(Enum):
     )
     extractiveness = SampleLevelMetricGrouping(
         metric_name=["summarization_coverage", "summarization_density", "summarization_compression"],
-        sample_level_fn=extractiveness,
+        sample_level_fn=Extractiveness(
+            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text"
+        ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.SUMMARIZATION,
         corpus_level_fn={
@@ -223,7 +225,9 @@ class Metrics(Enum):
     )
     faithfulness = SampleLevelMetric(
         metric_name="summac",
-        sample_level_fn=faithfulness,
+        sample_level_fn=Faithfulness(
+            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text"
+        ).compute,
         category=MetricCategory.GENERATIVE,
         use_case=MetricUseCase.SUMMARIZATION,
         corpus_level_fn=np.mean,
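Because the metric logic is now wrapped in classes, a task that stores its source document under a different key can reuse the same metric without code changes. A minimal sketch of that, assuming a hypothetical dataset whose source text lives in an "article" column (the import paths below match those referenced in the diff; "article" and the variable name are illustrative only):

from lighteval.metrics.metrics_sample import Extractiveness
from lighteval.metrics.normalizations import remove_braces, remove_braces_and_strip

# Same metric as above, pointed at a hypothetical "article" column
# instead of the default "text".
extractiveness_on_articles = Extractiveness(
    normalize_input=remove_braces,
    normalize_pred=remove_braces_and_strip,
    input_column="article",
).compute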

src/lighteval/metrics/metrics_sample.py

Lines changed: 98 additions & 18 deletions
@@ -593,24 +593,104 @@ def compute(self, golds: list[str], predictions: list[str]) -> dict:
         return {"BERTScore-P": p[0].item(), "BERTScore-R": r[0].item(), "BERTScore-F": f[0].item()}
 
 
-# todo: make into clean classes with call to normalizer
-def extractiveness(formatted_doc: Doc, predictions: list[str], **kwargs):
-    inp = remove_braces(formatted_doc.specific["text"])
-    pred = remove_braces_and_strip(predictions[0])
-    stats = DataStatsMetric().evaluate_example(pred, inp)
-    return {
-        "summarization_coverage": stats["coverage"],
-        "summarization_density": stats["density"],
-        "summarization_compression": stats["compression"],
-    }
-
-
-# todo: make into clean classes with call to normalizer
-def faithfulness(formatted_doc: Doc, predictions: list[str], **kwargs):
-    inp = remove_braces(formatted_doc.specific["text"])
-    pred = remove_braces_and_strip(predictions[0])
-    summac = SummaCZS(granularity="sentence", model_name="vitc", imager_load_cache=False)  # , device=device)
-    return summac.score_one(inp, pred)["score"]
+class Extractiveness:
+    def __init__(
+        self,
+        normalize_input: callable = remove_braces,
+        normalize_pred: callable = remove_braces_and_strip,
+        input_column: str = "text",
+    ):
+        """
+        Extractiveness metric class.
+
+        Args:
+            normalize_input (callable, optional): Function used to normalize the input strings.
+                Defaults to remove_braces from lighteval.metrics.normalizations; pass None to skip normalization.
+            normalize_pred (callable, optional): Function used to normalize the predicted strings.
+                Defaults to remove_braces_and_strip from lighteval.metrics.normalizations; pass None to skip normalization.
+            input_column (str): Column of the formatted_doc to use as the input. Defaults to "text".
+        """
+        self.stats_metric = None
+        self.normalize_input = normalize_input
+        self.normalize_pred = normalize_pred
+        self.input_column = input_column
+
+    def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
+        """
+        Compute the extractiveness of the predictions.
+
+        This method calculates coverage, density, and compression scores for a single
+        prediction against the input text.
+
+        Args:
+            predictions (list[str]): Predicted strings, a list of length 1.
+            formatted_doc (Doc): The formatted document.
+
+        Returns:
+            dict[str, float]: The extractiveness scores.
+        """
+        if self.stats_metric is None:
+            self.stats_metric = DataStatsMetric()
+
+        inp = formatted_doc.specific[self.input_column]
+        prediction = predictions[0]
+        if self.normalize_input:
+            inp = self.normalize_input(inp)
+        if self.normalize_pred:
+            prediction = self.normalize_pred(prediction)
+
+        stats = self.stats_metric.evaluate_example(prediction, inp)
+        return {
+            "summarization_coverage": stats["coverage"],
+            "summarization_density": stats["density"],
+            "summarization_compression": stats["compression"],
+        }
+
+
+class Faithfulness:
+    def __init__(
+        self,
+        normalize_input: callable = remove_braces,
+        normalize_pred: callable = remove_braces_and_strip,
+        input_column: str = "text",
+    ):
+        """
+        Faithfulness metric class.
+
+        Args:
+            normalize_input (callable, optional): Function used to normalize the input strings.
+                Defaults to remove_braces from lighteval.metrics.normalizations; pass None to skip normalization.
+            normalize_pred (callable, optional): Function used to normalize the predicted strings.
+                Defaults to remove_braces_and_strip from lighteval.metrics.normalizations; pass None to skip normalization.
+            input_column (str): Column of the formatted_doc to use as the input. Defaults to "text".
+        """
+        self.summac = None
+        self.normalize_input = normalize_input
+        self.normalize_pred = normalize_pred
+        self.input_column = input_column
+
+    def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> float:
+        """
+        Compute the faithfulness of the predictions.
+
+        The SummaCZS (Summary Consistency Zero-Shot) model is used to score the consistency
+        between the input text and the prediction.
+
+        Args:
+            predictions (list[str]): Predicted strings, a list of length 1.
+            formatted_doc (Doc): The formatted document.
+
+        Returns:
+            float: The faithfulness score.
+        """
+        if self.summac is None:
+            self.summac = SummaCZS(granularity="sentence", model_name="vitc", imager_load_cache=False)  # , device=device
+        inp = formatted_doc.specific[self.input_column]
+        prediction = predictions[0]
+        if self.normalize_input:
+            inp = self.normalize_input(inp)
+        if self.normalize_pred:
+            prediction = self.normalize_pred(prediction)
+        return self.summac.score_one(inp, prediction)["score"]
 
 
 class BLEURT:
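Both classes defer loading their heavy backends (DataStatsMetric, SummaCZS) until the first compute call, so merely declaring the metrics in metrics.py stays cheap. A usage sketch of the new interface follows; it assumes Doc (from lighteval.tasks.requests) accepts the query/choices/gold_index/specific fields shown and that the example strings are placeholders, so treat it as illustrative rather than exact:

from lighteval.metrics.metrics_sample import Extractiveness
from lighteval.tasks.requests import Doc  # assumed location of the Doc dataclass

# Build a document whose source text sits in the default "text" column.
doc = Doc(
    query="Summarize the text.",
    choices=[],
    gold_index=0,
    specific={"text": "{The quick brown fox jumps over the lazy dog.}"},
)

# First call lazily instantiates DataStatsMetric, then scores the prediction.
scores = Extractiveness(input_column="text").compute(
    predictions=["The quick brown fox jumps over the dog."],
    formatted_doc=doc,
)
# scores is a dict with keys "summarization_coverage",
# "summarization_density", and "summarization_compression".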
