
Commit 2dc1788

rolshoven and NathanHB authored
Multilingual extractiveness (#956)
* Added German, French, and Italian language support to Extractiveness metric
* Added minimum version for spacy dependency
* Added changes from code review
* Added missing newline

Co-authored-by: Nathan Habib <[email protected]>
1 parent 96a6882 commit 2dc1788

File tree (4 files changed: +87 −21 lines)

pyproject.toml
src/lighteval/metrics/imports/data_stats_metric.py
src/lighteval/metrics/metrics.py
src/lighteval/metrics/metrics_sample.py


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ extended_tasks = [
 s3 = ["s3fs"]
 multilingual = [
     "stanza",
-    "spacy[ja,ko,th]",
+    "spacy[ja,ko,th]>=3.8.0",
     "jieba", # for chinese tokenizer
     "pyvi", # for vietnamese tokenizer
 ]
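The only change here is a version floor on the spaCy extra. As a quick sanity check, not part of the commit, the installed version can be compared against the new constraint; this sketch assumes the packaging library is available (it ships with any environment that has pip or spaCy installed):

from importlib.metadata import version
from packaging.specifiers import SpecifierSet

# The ">=3.8.0" floor mirrors the constraint added to the "multilingual" extra above.
installed = version("spacy")
assert installed in SpecifierSet(">=3.8.0"), f"spaCy {installed} is older than 3.8.0"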

src/lighteval/metrics/imports/data_stats_metric.py

Lines changed: 29 additions & 11 deletions
@@ -27,15 +27,20 @@
 import logging
 from collections import Counter
 from multiprocessing import Pool
+from typing import Literal

 from lighteval.metrics.imports.data_stats_utils import Fragments
 from lighteval.utils.imports import NO_SPACY_ERROR_MSG, is_spacy_available


 logger = logging.getLogger(__name__)

-
-_en = None
+LANGUAGE_TO_SPACY_MODEL_MAP = {
+    "en": "en_core_web_sm",
+    "de": "de_core_news_sm",
+    "fr": "fr_core_news_sm",
+    "it": "it_core_news_sm",
+}


 class Metric:
@@ -51,8 +56,16 @@ def find_ngrams(input_list, n):


 class DataStatsMetric(Metric):
-    def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True):
-        """Data Statistics metric
+    def __init__(
+        self,
+        n_gram: int = 3,
+        n_workers: int = 24,
+        case: bool = False,
+        tokenize: bool = True,
+        language: Literal["en", "de", "fr", "it"] = "en",
+    ):
+        """
+        Data Statistics metric
         Makes use of Newsroom code: \
             https://github.com/lil-lab/newsroom/blob/master/newsroom/analyze/fragments.py
         Calculates extractive statistics such as coverage, density, compression as
@@ -69,6 +82,9 @@ def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True):
             case (bool): whether to lowercase input before calculating statistics.
             tokenize (bool): whether to tokenize the input; otherwise assumes that the input
                 is a string of space-separated tokens.
+            language (Literal["en", "de", "fr", "it"]): the language of the input text. This
+                determines the spaCy model used for tokenization. Currently supports English,
+                German, French, and Italian.
         """
         if not is_spacy_available():
             raise ImportError(NO_SPACY_ERROR_MSG)
@@ -78,22 +94,24 @@ def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True):
         self.n_workers = n_workers
         self.case = case
         self.tokenize = tokenize
+        self.language = language
+        self.nlp = None

-        global _en
+        spacy_model = LANGUAGE_TO_SPACY_MODEL_MAP.get(self.language, "en_core_web_sm")
         try:
-            _en = spacy.load("en_core_web_sm")
+            self.nlp = spacy.load(spacy_model)
         except OSError:
-            logger.info("Downloading the spacy en_core_web_sm model\n(don't worry, this will only happen once)")
+            logger.info(f"Downloading the spacy {spacy_model} model\n(don't worry, this will only happen once)")
             from spacy.cli import download

-            download("en_core_web_sm")
-            _en = spacy.load("en_core_web_sm")
+            download(spacy_model)
+            self.nlp = spacy.load(spacy_model)

     def evaluate_example(self, summary, input_text):
         if self.tokenize:
-            input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"])
+            input_text = self.nlp(input_text, disable=["tagger", "parser", "ner", "textcat"])
             input_text = [tok.text for tok in input_text]
-            summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"])
+            summary = self.nlp(summary, disable=["tagger", "parser", "ner", "textcat"])
             summary = [tok.text for tok in summary]
         fragments = Fragments(summary, input_text, case=self.case)
         coverage = fragments.coverage()
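Taken together, these hunks replace the module-level _en pipeline with a per-instance self.nlp selected from LANGUAGE_TO_SPACY_MODEL_MAP. A minimal usage sketch, not part of the commit (the German example sentences are invented, and the first run downloads the chosen spaCy model):

from lighteval.metrics.imports.data_stats_metric import DataStatsMetric

# "de" selects de_core_news_sm; unknown codes fall back to en_core_web_sm.
metric = DataStatsMetric(language="de")
scores = metric.evaluate_example(
    summary="Der Vertrag wurde im Mai unterzeichnet.",
    input_text="Der Vertrag wurde im Mai in Berlin unterzeichnet und gilt ab Juli.",
)
print(scores)  # extractive statistics such as coverage, density, and compression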

src/lighteval/metrics/metrics.py

Lines changed: 53 additions & 8 deletions
@@ -26,9 +26,7 @@
 import numpy as np
 from aenum import Enum

-from lighteval.metrics.dynamic_metrics import (
-    MultilingualExtractiveMatchMetric,
-)
+from lighteval.metrics.dynamic_metrics import MultilingualExtractiveMatchMetric
 from lighteval.metrics.harness_compatibility.drop import DropMetrics
 from lighteval.metrics.harness_compatibility.truthful_qa import TruthfulqaMCMetrics
 from lighteval.metrics.metrics_corpus import (
@@ -57,11 +55,7 @@
     Recall,
     StringDistance,
 )
-from lighteval.metrics.normalizations import (
-    bigbench_normalizer,
-    remove_braces,
-    remove_braces_and_strip,
-)
+from lighteval.metrics.normalizations import bigbench_normalizer, remove_braces, remove_braces_and_strip
 from lighteval.metrics.sample_preparator import (
     GenerativePreparator,
     LoglikelihoodPreparator,
@@ -231,6 +225,57 @@ class Metrics(Enum):
             "summarization_compression": True,
         },
     )
+    extractiveness_de = SampleLevelMetricGrouping(
+        metric_name=["summarization_coverage", "summarization_density", "summarization_compression"],
+        sample_level_fn=Extractiveness(
+            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text", language="de"
+        ),
+        category=SamplingMethod.GENERATIVE,
+        corpus_level_fn={
+            "summarization_coverage": np.mean,
+            "summarization_density": np.mean,
+            "summarization_compression": np.mean,
+        },
+        higher_is_better={
+            "summarization_coverage": True,
+            "summarization_density": True,
+            "summarization_compression": True,
+        },
+    )
+    extractiveness_fr = SampleLevelMetricGrouping(
+        metric_name=["summarization_coverage", "summarization_density", "summarization_compression"],
+        sample_level_fn=Extractiveness(
+            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text", language="fr"
+        ),
+        category=SamplingMethod.GENERATIVE,
+        corpus_level_fn={
+            "summarization_coverage": np.mean,
+            "summarization_density": np.mean,
+            "summarization_compression": np.mean,
+        },
+        higher_is_better={
+            "summarization_coverage": True,
+            "summarization_density": True,
+            "summarization_compression": True,
+        },
+    )
+    extractiveness_it = SampleLevelMetricGrouping(
+        metric_name=["summarization_coverage", "summarization_density", "summarization_compression"],
+        sample_level_fn=Extractiveness(
+            normalize_input=remove_braces, normalize_pred=remove_braces_and_strip, input_column="text", language="it"
+        ),
+        category=SamplingMethod.GENERATIVE,
+        corpus_level_fn={
+            "summarization_coverage": np.mean,
+            "summarization_density": np.mean,
+            "summarization_compression": np.mean,
+        },
+        higher_is_better={
+            "summarization_coverage": True,
+            "summarization_density": True,
+            "summarization_compression": True,
+        },
+    )
     f1_score = SampleLevelMetric(
         metric_name="f1",
         sample_level_fn=F1_score(),
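The three new enum members mirror the existing English extractiveness entry, differing only in the language passed to Extractiveness. A small inspection sketch, not part of the commit, assuming (as the diff suggests) that each member's value is the SampleLevelMetricGrouping built above and that it exposes metric_name as an attribute:

from lighteval.metrics.metrics import Metrics

for member in (Metrics.extractiveness_de, Metrics.extractiveness_fr, Metrics.extractiveness_it):
    grouping = member.value  # the SampleLevelMetricGrouping defined in the diff
    print(member.name, grouping.metric_name)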

src/lighteval/metrics/metrics_sample.py

Lines changed: 4 additions & 1 deletion
@@ -664,6 +664,7 @@ def __init__(
         normalize_input: callable = remove_braces,
         normalize_pred: callable = remove_braces_and_strip,
         input_column: str = "text",
+        language: Literal["en", "de", "fr", "it"] = "en",
     ):
         """Extractiveness metric class.

@@ -673,11 +674,13 @@ def __init__(
             normalize_pred (callable, optional): Function to use to normalize the predicted strings.
                 Defaults to remove_braces_and_strip from lighteval.metrics.normalizations if no normalization is applied.
            input_column (str): Column in the formatted_doc to use for the input. Defaults to "text".
+            language (str): Language ISO code for the input text. Defaults to "en".
         """
         self.stats_metric = None
         self.normalize_input = normalize_input
         self.normalize_pred = normalize_pred
         self.input_column = input_column
+        self.language = language

     def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str, float]:
         """Compute the extractiveness of the predictions.
@@ -694,7 +697,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str, float]:
             dict[str, float]: The extractiveness scores.
         """
         if self.stats_metric is None:
-            self.stats_metric = DataStatsMetric()
+            self.stats_metric = DataStatsMetric(language=self.language)

         inp = doc.specific[self.input_column]
         prediction = model_response.final_text[0]
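Because stats_metric is built lazily, constructing Extractiveness is cheap and does not touch spaCy; the language only takes effect on the first compute() call. A minimal sketch, not part of the commit:

from lighteval.metrics.metrics_sample import Extractiveness

extractiveness = Extractiveness(input_column="text", language="fr")
# The spaCy-backed metric is created on the first compute() call,
# as DataStatsMetric(language="fr"), so no model is downloaded here.
assert extractiveness.stats_metric is None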
