diff --git a/CHANGELOG.md b/CHANGELOG.md
index 55742672139..f8139fff992 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -132,6 +132,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed `BestScore` on GPU ([#912](https://github.com/PyTorchLightning/metrics/pull/912))
 
+- Fixed Lsum computation for `ROUGEScore` ([#944](https://github.com/PyTorchLightning/metrics/pull/944))
+
 
 ## [0.7.3] - 2022-03-23
 
diff --git a/tests/text/test_rouge.py b/tests/text/test_rouge.py
index 91446ad7b1e..26cf7480405 100644
--- a/tests/text/test_rouge.py
+++ b/tests/text/test_rouge.py
@@ -14,14 +14,15 @@
 import re
 from functools import partial
-from typing import Callable, Sequence
+from typing import Callable, Sequence, Union
 
-import numpy as np
 import pytest
 import torch
+from torch import Tensor
+from typing_extensions import Literal
 
 from tests.text.helpers import TextTester
-from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_single_reference
+from tests.text.inputs import Input, _inputs_multiple_references, _inputs_single_sentence_single_reference
 from torchmetrics.functional.text.rouge import rouge_score
 from torchmetrics.text.rouge import ROUGEScore
 from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
@@ -35,25 +36,36 @@
 
 ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")
 
+# Some randomly adjusted input from the CNN/DailyMail dataset which breaks the test
+_preds = "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto ."
+_target = "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of ."
+_inputs_summarization = Input(preds=_preds, targets=_target)
+
+
 def _compute_rouge_score(
-    preds: Sequence[str],
-    target: Sequence[Sequence[str]],
+    preds: Union[str, Sequence[str]],
+    target: Union[str, Sequence[Union[str, Sequence[str]]]],
     use_stemmer: bool,
     rouge_level: str,
     metric: str,
-    accumulate: str,
-):
+    accumulate: Literal["avg", "best", None],
+) -> Tensor:
     """Evaluates rouge scores from rouge-score package for baseline evaluation."""
     if isinstance(target, list) and all(isinstance(tgt, str) for tgt in target):
         target = [target] if isinstance(preds, str) else [[tgt] for tgt in target]
 
-    if isinstance(preds, str):
+    if isinstance(preds, str) and accumulate:
         preds = [preds]
 
-    if isinstance(target, str):
+    if isinstance(target, str) and accumulate:
         target = [[target]]
 
     scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
+
+    if not accumulate:
+        rs_scores = scorer.score(target, preds)
+        rs_result = getattr(rs_scores[rouge_level], metric)
+        return torch.tensor(rs_result, dtype=torch.float)
+
     aggregator = BootstrapAggregator()
 
     for target_raw, pred_raw in zip(target, preds):
@@ -75,7 +87,7 @@
 
     rs_scores = aggregator.aggregate()
     rs_result = getattr(rs_scores[rouge_level].mid, metric)
-    return rs_result
+    return torch.tensor(rs_result, dtype=torch.float)
 
 
 @pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
@@ -208,4 +220,37 @@ def test_rouge_metric_normalizer_tokenizer(pl_rouge_metric_key):
     )
     metrics_score = scorer.compute()
 
-    np.isclose(metrics_score[rouge_level + "_" + metric], original_score, atol=1e-8, equal_nan=True)
+    assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
+
+
+@pytest.mark.parametrize(
+    "pl_rouge_metric_key",
+    [
+        "rougeL_precision",
+        "rougeL_recall",
+        "rougeL_fmeasure",
+        "rougeLsum_precision",
+        "rougeLsum_recall",
+        "rougeLsum_fmeasure",
+    ],
+)
+@pytest.mark.parametrize("use_stemmer", [False, True])
+def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer):
+    """Specific tests to verify the correctness of the Rouge-L and Rouge-LSum metrics."""
+    rouge_level, metric = pl_rouge_metric_key.split("_")
+    original_score = _compute_rouge_score(
+        preds=_inputs_summarization.preds,
+        target=_inputs_summarization.targets,
+        rouge_level=rouge_level,
+        metric=metric,
+        accumulate=None,
+        use_stemmer=use_stemmer,
+    )
+
+    metrics_score = rouge_score(
+        _inputs_summarization.preds,
+        _inputs_summarization.targets,
+        rouge_keys=rouge_level,
+        use_stemmer=use_stemmer,
+    )
+    assert torch.isclose(metrics_score[rouge_level + "_" + metric], original_score)
diff --git a/torchmetrics/functional/text/rouge.py b/torchmetrics/functional/text/rouge.py
index 8670278d654..1bb1d92c12e 100644
--- a/torchmetrics/functional/text/rouge.py
+++ b/torchmetrics/functional/text/rouge.py
@@ -39,8 +39,8 @@
 ALLOWED_ACCUMULATE_VALUES = ("avg", "best")
 
 
-def _add_newline_to_end_of_each_sentence(x: str) -> str:
-    """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
+def _split_sentence(x: str) -> Sequence[str]:
+    """Split the text into sentences to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
     if not _NLTK_AVAILABLE:
         raise ModuleNotFoundError("ROUGE-Lsum calculation requires that `nltk` is installed. Use `pip install nltk`.")
     import nltk
@@ -48,7 +48,7 @@ def _add_newline_to_end_of_each_sentence(x: str) -> str:
 
     nltk.download("punkt", quiet=True, force=False)
     re.sub("<n>", "", x)  # remove pegasus newline char
-    return "\n".join(nltk.sent_tokenize(x))
+    return nltk.sent_tokenize(x)
 
 
 def _compute_metrics(hits_or_lcs: int, pred_len: int, target_len: int) -> Dict[str, Tensor]:
@@ -72,7 +72,9 @@
     return dict(precision=tensor(precision), recall=tensor(recall), fmeasure=tensor(fmeasure))
 
 
-def _lcs(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> int:
+def _lcs(
+    pred_tokens: Sequence[str], target_tokens: Sequence[str], return_full_table: bool = False
+) -> Union[int, Sequence[Sequence[int]]]:
     """Common DP algorithm to compute the length of the longest common subsequence.
 
     Args:
@@ -88,9 +90,66 @@
                 lcs[i][j] = lcs[i - 1][j - 1] + 1
             else:
                 lcs[i][j] = max(lcs[i - 1][j], lcs[i][j - 1])
+    if return_full_table:
+        return lcs
     return lcs[-1][-1]
 
 
+def _backtracked_lcs(
+    lcs_table: Sequence[Sequence[int]], pred_tokens: Sequence[str], target_tokens: Sequence[str]
+) -> Sequence[int]:
+    """Backtrack the LCS table to recover one longest common subsequence.
+
+    Args:
+        lcs_table:
+            A table containing information for the calculation of the longest common subsequence.
+        pred_tokens:
+            A tokenized predicted sentence.
+        target_tokens:
+            A tokenized target sentence.
+    """
+    i = len(pred_tokens)
+    j = len(target_tokens)
+    backtracked_lcs: List[int] = []
+    while i > 0 and j > 0:
+        if pred_tokens[i - 1] == target_tokens[j - 1]:
+            backtracked_lcs.insert(0, j - 1)
+            i -= 1
+            j -= 1
+        elif lcs_table[j][i - 1] > lcs_table[j - 1][i]:
+            i -= 1
+        else:
+            j -= 1
+    return backtracked_lcs
+
+
+def _union_lcs(pred_tokens_list: Sequence[Sequence[str]], target_tokens: Sequence[str]) -> Sequence[str]:
+    """Find the union of longest common subsequences between a target sentence and an iterable of predicted sentences.
+
+    Args:
+        pred_tokens_list:
+            A list of tokenized predicted sentences.
+        target_tokens:
+            A single tokenized target sentence.
+
+    Return:
+        A list of the target tokens that belong to the union of the longest common subsequences.
+    """
+
+    def lcs_ind(pred_tokens: Sequence[str], target_tokens: Sequence[str]) -> Sequence[int]:
+        """Return the target indices of one longest common subsequence via the backtracked LCS table."""
+        lcs_table: Sequence[Sequence[int]] = _lcs(pred_tokens, target_tokens, return_full_table=True)  # type: ignore
+        backtracked_lcs_table = _backtracked_lcs(lcs_table, pred_tokens, target_tokens)
+        return backtracked_lcs_table
+
+    def find_union(lcs_tables: Sequence[Sequence[int]]) -> Sequence[int]:
+        """Find the union of the LCS indices given a list of LCS index sequences."""
+        return sorted(list(set().union(*lcs_tables)))  # type: ignore
+
+    lcs_tables = [lcs_ind(pred_tokens, target_tokens) for pred_tokens in pred_tokens_list]
+    union_lcs = [target_tokens[i] for i in find_union(lcs_tables)]
+    return union_lcs
+
+
 def _normalize_and_tokenize_text(
     text: str,
     stemmer: Optional[Any] = None,
@@ -160,7 +219,7 @@ def _create_ngrams(tokens: Sequence[str], n: int) -> Counter:
 
 
 def _rouge_l_score(pred: Sequence[str], target: Sequence[str]) -> Dict[str, Tensor]:
-    """This computes precision, recall and F1 score for the Rouge-L or Rouge-LSum metric.
+    """This computes precision, recall and F1 score for the Rouge-L metric.
 
     Args:
         pred:
@@ -172,10 +231,52 @@
     if 0 in (pred_len, target_len):
         return dict(precision=tensor(0.0), recall=tensor(0.0), fmeasure=tensor(0.0))
 
-    lcs = _lcs(pred, target)
+    lcs: int = _lcs(pred, target)  # type: ignore
     return _compute_metrics(lcs, pred_len, target_len)
 
 
+def _rouge_lsum_score(pred: Sequence[Sequence[str]], target: Sequence[Sequence[str]]) -> Dict[str, Tensor]:
+    """This computes precision, recall and F1 score for the Rouge-LSum metric. More information can be found in
+    Section 3.2 of the referenced paper [1]. This implementation follows the official implementation from:
+    https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py
+
+    Args:
+        pred:
+            An iterable of tokenized predicted sentences.
+        target:
+            An iterable of tokenized target sentences.
+
+    References:
+        [1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin.
+            https://aclanthology.org/W04-1013/
+    """
+    pred_len = sum(map(len, pred))
+    target_len = sum(map(len, target))
+    if 0 in (pred_len, target_len):
+        return dict(precision=tensor(0.0), recall=tensor(0.0), fmeasure=tensor(0.0))
+
+    # Get token counts
+    def _get_token_counts(sentences: Sequence[Sequence[str]]) -> Counter:
+        ngrams: Counter = Counter()
+        for sentence in sentences:
+            ngrams.update(sentence)
+        return ngrams
+
+    pred_tokens_count = _get_token_counts(pred)
+    target_tokens_count = _get_token_counts(target)
+
+    # Calculate hits
+    hits = 0
+    for tgt in target:
+        lcs = _union_lcs(pred, tgt)
+        for token in lcs:
+            if pred_tokens_count[token] > 0 and target_tokens_count[token] > 0:
+                hits += 1
+                pred_tokens_count[token] -= 1
+                target_tokens_count[token] -= 1
+
+    return _compute_metrics(hits, pred_len, target_len)
+
+
 def _rouge_score_update(
     preds: Sequence[str],
     target: Sequence[Sequence[str]],
@@ -239,27 +340,27 @@
         result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values}
         list_results = []
         pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer)
-        pred_lsum = _normalize_and_tokenize_text(
-            _add_newline_to_end_of_each_sentence(pred_raw), stemmer, normalizer, tokenizer
-        )
+        pred_lsum = [
+            _normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer)
+            for pred_sentence in _split_sentence(pred_raw)
+        ]
 
         for target_raw_inner in target_raw:
             tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer, normalizer, tokenizer)
 
             if "Lsum" in rouge_keys_values:
-                # rougeLsum expects "\n" separated sentences within a summary
-                target_lsum = _normalize_and_tokenize_text(
-                    _add_newline_to_end_of_each_sentence(target_raw_inner), stemmer, normalizer, tokenizer
-                )
+                target_lsum = [
+                    _normalize_and_tokenize_text(tgt_sentence, stemmer, normalizer, tokenizer)
+                    for tgt_sentence in _split_sentence(target_raw_inner)
+                ]
 
             for rouge_key in rouge_keys_values:
                 if isinstance(rouge_key, int):
                     score = _rouge_n_score(pred, tgt, rouge_key)
-                else:
-                    score = _rouge_l_score(
-                        pred if rouge_key != "Lsum" else pred_lsum,
-                        tgt if rouge_key != "Lsum" else target_lsum,
-                    )
+                elif rouge_key == "L":
+                    score = _rouge_l_score(pred, tgt)
+                elif rouge_key == "Lsum":
+                    score = _rouge_lsum_score(pred_lsum, target_lsum)
 
                 result_inner[rouge_key] = score
                 result_avg[rouge_key].append(score)
             list_results.append(result_inner.copy())
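
For readers following the diff, below is a small, self-contained sketch of the union-LCS scoring idea that the new _union_lcs and _rouge_lsum_score helpers implement. It is an illustration only: the names (lcs_table, lcs_indices, rouge_lsum_fmeasure) and the toy sentences are made up for this sketch and are not part of the library.

# Illustrative sketch (not library code): ROUGE-Lsum as union-LCS hits clipped by token counts.
from collections import Counter
from typing import List, Sequence


def lcs_table(a: Sequence[str], b: Sequence[str]) -> List[List[int]]:
    # Standard DP table; table[i][j] is the LCS length of a[:i] and b[:j].
    table = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            if a[i - 1] == b[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table


def lcs_indices(pred: Sequence[str], tgt: Sequence[str]) -> List[int]:
    # Backtrack the table to recover the target indices of one LCS.
    table = lcs_table(pred, tgt)
    i, j = len(pred), len(tgt)
    out: List[int] = []
    while i > 0 and j > 0:
        if pred[i - 1] == tgt[j - 1]:
            out.insert(0, j - 1)
            i -= 1
            j -= 1
        elif table[i - 1][j] > table[i][j - 1]:
            i -= 1
        else:
            j -= 1
    return out


def rouge_lsum_fmeasure(pred_sents: List[List[str]], tgt_sents: List[List[str]]) -> float:
    # Union-LCS hits, clipped by per-token counts so a token is never credited twice.
    pred_counts = Counter(tok for sent in pred_sents for tok in sent)
    tgt_counts = Counter(tok for sent in tgt_sents for tok in sent)
    hits = 0
    for tgt in tgt_sents:
        union = sorted(set().union(*(lcs_indices(p, tgt) for p in pred_sents)))
        for idx in union:
            token = tgt[idx]
            if pred_counts[token] > 0 and tgt_counts[token] > 0:
                hits += 1
                pred_counts[token] -= 1
                tgt_counts[token] -= 1
    pred_len = sum(len(s) for s in pred_sents)
    tgt_len = sum(len(s) for s in tgt_sents)
    if hits == 0:
        return 0.0
    precision, recall = hits / pred_len, hits / tgt_len
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    pred = [["the", "cat", "sat"], ["it", "was", "happy"]]
    target = [["the", "cat", "sat", "down"], ["the", "cat", "was", "happy"]]
    print(rouge_lsum_fmeasure(pred, target))  # ~0.714 with these toy sentences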
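
As a quick usage note, the corrected Lsum path can be exercised through the functional API the same way the new test_rouge_lsum_score does; this assumes a torchmetrics build containing this patch and nltk installed for sentence splitting.

# Mirrors the inputs used by the new test; rougeLsum is now computed with union-LCS
# over nltk-split sentences instead of plain Rouge-L over "\n"-joined text.
from torchmetrics.functional.text.rouge import rouge_score

preds = "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto ."
target = "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of ."

scores = rouge_score(preds, target, rouge_keys="rougeLsum", use_stemmer=False)
print(scores["rougeLsum_fmeasure"])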