Fix RougeL/RougeLSum implementation #944

Merged 7 commits on Apr 11, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -132,6 +132,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed `BestScore` on GPU ([#912](https://github.com/PyTorchLightning/metrics/pull/912))

- Fixed Lsum computation for `ROUGEScore` ([#944](https://github.com/PyTorchLightning/metrics/pull/944))


## [0.7.3] - 2022-03-23

67 changes: 56 additions & 11 deletions tests/text/test_rouge.py
@@ -14,14 +14,15 @@

import re
from functools import partial
from typing import Callable, Sequence
from typing import Callable, Sequence, Union

import numpy as np
import pytest
import torch
from torch import Tensor
from typing_extensions import Literal

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_single_reference
from tests.text.inputs import Input, _inputs_multiple_references, _inputs_single_sentence_single_reference
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
@@ -35,25 +36,36 @@
ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")


# Some randomly adjusted input from the CNN/DailyMail dataset which breaks the test
_preds = "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto ."
_target = "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of ."
_inputs_summarization = Input(preds=_preds, targets=_target)


def _compute_rouge_score(
preds: Sequence[str],
target: Sequence[Sequence[str]],
preds: Union[str, Sequence[str]],
target: Union[str, Sequence[Union[str, Sequence[str]]]],
use_stemmer: bool,
rouge_level: str,
metric: str,
accumulate: str,
):
    accumulate: Literal["avg", "best", None],
) -> Tensor:
"""Evaluates rouge scores from rouge-score package for baseline evaluation."""
if isinstance(target, list) and all(isinstance(tgt, str) for tgt in target):
target = [target] if isinstance(preds, str) else [[tgt] for tgt in target]

if isinstance(preds, str):
if isinstance(preds, str) and accumulate:
preds = [preds]

if isinstance(target, str):
if isinstance(target, str) and accumulate:
target = [[target]]

scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
if not accumulate:
rs_scores = scorer.score(target, preds)
rs_result = getattr(rs_scores[rouge_level], metric)
return torch.tensor(rs_result, dtype=torch.float)

aggregator = BootstrapAggregator()

for target_raw, pred_raw in zip(target, preds):
@@ -75,7 +87,7 @@ def _compute_rouge_score(

rs_scores = aggregator.aggregate()
rs_result = getattr(rs_scores[rouge_level].mid, metric)
return rs_result
return torch.tensor(rs_result, dtype=torch.float)
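For context, the reference `rouge-score` package wrapped by this helper can also be queried directly; a minimal sketch (assuming the package is installed, and reusing the summarization strings defined above as example inputs):

```python
# Minimal sketch of scoring one prediction/target pair with the reference
# `rouge-score` package; the strings are the summarization inputs defined above.
from rouge_score.rouge_scorer import RougeScorer

scorer = RougeScorer(["rougeL", "rougeLsum"], use_stemmer=False)
scores = scorer.score(
    "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of .",  # target
    "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto .",  # prediction
)
# Each value is a Score tuple exposing precision, recall and fmeasure.
print(scores["rougeLsum"].fmeasure)
```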


@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@@ -208,4 +220,37 @@ def test_rouge_metric_normalizer_tokenizer(pl_rouge_metric_key):
)
metrics_score = scorer.compute()

np.isclose(metrics_score[rouge_level + "_" + metric], original_score, atol=1e-8, equal_nan=True)
assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)


@pytest.mark.parametrize(
"pl_rouge_metric_key",
[
"rougeL_precision",
"rougeL_recall",
"rougeL_fmeasure",
"rougeLsum_precision",
"rougeLsum_recall",
"rougeLsum_fmeasure",
],
)
@pytest.mark.parametrize("use_stemmer", [False, True])
def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
"""Specific tests to verify the correctness of Rouge-L and Rouge-LSum metric."""
rouge_level, metric = pl_rouge_metric_key.split("_")
original_score = _compute_rouge_score(
preds=_inputs_summarization.preds,
target=_inputs_summarization.targets,
rouge_level=rouge_level,
metric=metric,
accumulate=None,
use_stemmer=use_stemmer,
)

metrics_score = rouge_score(
_inputs_summarization.preds,
_inputs_summarization.targets,
rouge_keys=rouge_level,
use_stemmer=use_stemmer,
)
assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
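For reference, the functional metric exercised by this test can be called directly in the same way; a brief usage sketch (result keys follow the `<rouge_key>_<metric>` naming used in the assertion, and the actual tensor values depend on the inputs):

```python
# Brief usage sketch of the functional ROUGE metric tested above,
# reusing the summarization example strings from this test file.
from torchmetrics.functional.text.rouge import rouge_score

preds = "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto ."
target = "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of ."

result = rouge_score(preds, target, rouge_keys="rougeLsum", use_stemmer=False)
# A dict of tensors keyed as "<rouge_key>_<metric>", e.g. "rougeLsum_fmeasure".
print(result["rougeLsum_precision"], result["rougeLsum_recall"], result["rougeLsum_fmeasure"])
```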
137 changes: 119 additions & 18 deletions torchmetrics/functional/text/rouge.py
@@ -39,16 +39,16 @@
ALLOWED_ACCUMULATE_VALUES = ("avg", "best")


def _add_newline_to_end_of_each_sentence(x: str) -> str:
"""This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
def _split_sentence(x: str) -> Sequence[str]:
"""The sentence is split to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
if not _NLTK_AVAILABLE:
raise ModuleNotFoundError("ROUGE-Lsum calculation requires that `nltk` is installed. Use `pip install nltk`.")
import nltk

nltk.download("punkt", quiet=True, force=False)

    x = re.sub("<n>", "", x)  # remove pegasus newline char
return "\n".join(nltk.sent_tokenize(x))
return nltk.sent_tokenize(x)
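As a rough illustration of what the splitting above produces, a standalone sketch (assumes `nltk` is installed and the `punkt` data can be downloaded; the example text is made up):

```python
# Standalone sketch of the nltk sentence splitting used for ROUGE-Lsum
# (assumes nltk is installed and the "punkt" tokenizer data is available).
import nltk

nltk.download("punkt", quiet=True)

text = "This is the first sentence. This is the second one."
print(nltk.sent_tokenize(text))
# ['This is the first sentence.', 'This is the second one.']
```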


def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, Tensor]:
@@ -72,7 +72,9 @@ def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, Tensor]:
return dict(precision=tensor(precision), recall=tensor(recall), fmeasure=tensor(fmeasure))


def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
def _lcs(
pred_tokens: Sequence[str], target_tokens: Sequence[str], return_full_table: bool = False
) -> Union[int, Sequence[Sequence[int]]]:
"""Common DP algorithm to compute the length of the longest common subsequence.

Args:
@@ -88,9 +90,66 @@ def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
lcs[i][j] = lcs[i - 1][j - 1] + 1
else:
lcs[i][j] = max(lcs[i - 1][j], lcs[i][j - 1])
if return_full_table:
return lcs
return lcs[-1][-1]
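For readers unfamiliar with the dynamic program above, a minimal self-contained LCS-length sketch (illustrative only, not the library code; the table is indexed `[target][pred]` to match the backtracking below):

```python
# Minimal self-contained LCS-length dynamic program (illustrative sketch only).
from typing import Sequence


def lcs_length(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
    # table[i][j] is the LCS length of target_tokens[:i] and pred_tokens[:j]
    table = [[0] * (len(pred_tokens) + 1) for _ in range(len(target_tokens) + 1)]
    for i in range(1, len(target_tokens) + 1):
        for j in range(1, len(pred_tokens) + 1):
            if pred_tokens[j - 1] == target_tokens[i - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[-1][-1]


print(lcs_length("a b c d".split(), "a c d e".split()))  # 3, for the subsequence "a c d"
```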


def _backtracked_lcs(
lcs_table: Sequence[Sequence[int]], pred_tokens: Sequence[str], target_tokens: Sequence[str]
) -> Sequence[int]:
"""Backtrack LCS table.

Args:
lcs_table:
A table containing information for the calculation of the longest common subsequence.
pred_tokens:
A tokenized predicted sentence.
target_tokens:
A tokenized target sentence.
"""
i = len(pred_tokens)
j = len(target_tokens)
backtracked_lcs: List[int] = []
while i > 0 and j > 0:
if pred_tokens[i - 1] == target_tokens[j - 1]:
backtracked_lcs.insert(0, j - 1)
i -= 1
j -= 1
elif lcs_table[j][i - 1] > lcs_table[j - 1][i]:
i -= 1
else:
j -= 1
return backtracked_lcs


def _union_lcs(pred_tokens_list: Sequence[Sequence[str]], target_tokens: Sequence[str]) -> Sequence[str]:
"""Find union LCS between a target sentence and iterable of predicted tokens.

Args:
        pred_tokens_list:
            The tokenized sentences of the prediction (the predicted summary split by '\n' and tokenized).
        target_tokens:
            A single tokenized target sentence (one part of the '\n'-split target summary).

    Return:
        A list of target tokens in the union of the longest common subsequences found against each
        predicted sentence.
    """

def lcs_ind(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> Sequence[int]:
"""Returns one of the longest of longest common subsequence via backtracked lcs table."""
lcs_table: Sequence[Sequence[int]] = _lcs(pred_tokens, target_tokens, return_full_table=True) # type: ignore
backtracked_lcs_table = _backtracked_lcs(lcs_table, pred_tokens, target_tokens)
return backtracked_lcs_table

def find_union(lcs_tables: Sequence[Sequence[int]]) -> Sequence[int]:
"""Find union LCS given a list of LCS."""
return sorted(list(set().union(*lcs_tables))) # type: ignore

lcs_tables = [lcs_ind(pred_tokens, target_tokens) for pred_tokens in pred_tokens_list]
union_lcs = [target_tokens[i] for i in find_union(lcs_tables)]
return union_lcs
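To make the union-LCS idea concrete, a small self-contained sketch that reimplements the index recovery locally (illustrative only; the example sentences are made up):

```python
# Self-contained sketch of the union-LCS step: for one target sentence, take the
# union of LCS match positions against every predicted sentence (illustrative only).
from typing import List, Sequence


def lcs_indices(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> List[int]:
    """Indices of target tokens belonging to one LCS with the predicted tokens."""
    table = [[0] * (len(pred_tokens) + 1) for _ in range(len(target_tokens) + 1)]
    for i in range(1, len(target_tokens) + 1):
        for j in range(1, len(pred_tokens) + 1):
            if pred_tokens[j - 1] == target_tokens[i - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    indices: List[int] = []
    i, j = len(target_tokens), len(pred_tokens)
    while i > 0 and j > 0:
        if pred_tokens[j - 1] == target_tokens[i - 1]:
            indices.insert(0, i - 1)
            i -= 1
            j -= 1
        elif table[i][j - 1] > table[i - 1][j]:
            j -= 1
        else:
            i -= 1
    return indices


target_sentence = "the cat sat on the mat".split()
pred_sentences = ["the cat was here".split(), "it sat on a mat".split()]
union = sorted(set().union(*(lcs_indices(p, target_sentence) for p in pred_sentences)))
print([target_sentence[i] for i in union])  # ['the', 'cat', 'sat', 'on', 'mat']
```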


def _normalize_and_tokenize_text(
text: str,
stemmer: Optional[Any] = None,
@@ -160,7 +219,7 @@ def _create_ngrams(tokens: Sequence[str], n: int) -> Counter:


def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]:
"""This computes precision, recall and F1 score for the Rouge-L or Rouge-LSum metric.
"""This computes precision, recall and F1 score for the Rouge-L metric.

Args:
pred:
@@ -172,10 +231,52 @@ def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]:
if 0 in (pred_len, target_len):
return dict(precision=tensor(0.0), recall=tensor(0.0), fmeasure=tensor(0.0))

lcs = _lcs(pred, target)
lcs: int = _lcs(pred, target) # type: ignore
return _compute_metrics(lcs, pred_len, target_len)


def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> Dict[str, Tensor]:
"""This computes precision, recall and F1 score for the Rouge-LSum metric. More information can be found in Section
    3.2 of the referenced paper [1]. This implementation follows the official implementation from:
https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py

Args:
pred:
            An iterable of tokenized predicted sentences (the prediction split by '\n').
        target:
            An iterable of tokenized target sentences (the target split by '\n').

References
[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin. https://aclanthology.org/W04-1013/
"""
pred_len = sum(map(len, pred))
target_len = sum(map(len, target))
if 0 in (pred_len, target_len):
return dict(precision=tensor(0.0), recall=tensor(0.0), fmeasure=tensor(0.0))

# Get token counts
def _get_token_counts(sentences: Sequence[Sequence[str]]) -> Counter:
ngrams: Counter = Counter()
for sentence in sentences:
ngrams.update(sentence)
return ngrams

pred_tokens_count = _get_token_counts(pred)
target_tokens_count = _get_token_counts(target)

# Calculate hits
hits = 0
for tgt in target:
lcs = _union_lcs(pred, tgt)
for token in lcs:
if pred_tokens_count[token] > 0 and target_tokens_count[token] > 0:
hits += 1
pred_tokens_count[token] -= 1
target_tokens_count[token] -= 1

return _compute_metrics(hits, pred_len, target_len)
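The hit count is then converted to scores by `_compute_metrics` in the usual ROUGE fashion; a brief worked sketch with hypothetical counts:

```python
# Worked sketch of the final scoring step, using hypothetical counts:
# precision = hits / predicted token count, recall = hits / target token count.
hits, pred_len, target_len = 5, 10, 8

precision = hits / pred_len  # 0.5
recall = hits / target_len  # 0.625
fmeasure = 2 * precision * recall / (precision + recall)  # harmonic mean, ~0.556
print(precision, recall, fmeasure)
```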


def _rouge_score_update(
preds: Sequence[str],
target: Sequence[Sequence[str]],
@@ -239,27 +340,27 @@ def _rouge_score_update(
result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}
list_results = []
pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer)
pred_lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(pred_raw), stemmer, normalizer, tokenizer
)
pred_lsum = [
_normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer)
for pred_sentence in _split_sentence(pred_raw)
]

for target_raw_inner in target_raw:
tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer, normalizer, tokenizer)

if "Lsum" in rouge_keys_values:
# rougeLsum expects "\n" separated sentences within a summary
target_lsum = _normalize_and_tokenize_text(
_add_newline_to_end_of_each_sentence(target_raw_inner), stemmer, normalizer, tokenizer
)
target_lsum = [
_normalize_and_tokenize_text(tgt_sentence, stemmer, normalizer, tokenizer)
for tgt_sentence in _split_sentence(target_raw_inner)
]

for rouge_key in rouge_keys_values:
if isinstance(rouge_key, int):
score = _rouge_n_score(pred, tgt, rouge_key)
else:
score = _rouge_l_score(
pred if rouge_key != "Lsum" else pred_lsum,
tgt if rouge_key != "Lsum" else target_lsum,
)
elif rouge_key == "L":
score = _rouge_l_score(pred, tgt)
elif rouge_key == "Lsum":
score = _rouge_lsum_score(pred_lsum, target_lsum)
result_inner[rouge_key] = score
result_avg[rouge_key].append(score)
list_results.append(result_inner.copy())