Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Factored out ROUGE and BLEU metrics (#563)
Browse files Browse the repository at this point in the history
* Factored out ROUGE and BLEU metrics

* Changed docs references

* Updates

Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
karthikrangasai and ethanwharris authored Jul 12, 2021
1 parent a24746a commit b70f940
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 163 deletions.
18 changes: 6 additions & 12 deletions docs/source/code/text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ Finetuning

.. automodule:: flash.text.seq2seq.core.finetuning

Metrics
*******

.. automodule:: flash.text.seq2seq.core.metrics
.. automodule:: flash.text.seq2seq.core.utils

Summarization
=============

Expand All @@ -55,13 +61,6 @@ Task

.. automodule:: flash.text.seq2seq.summarization.model

Metric
******

.. automodule:: flash.text.seq2seq.summarization.metric

.. automodule:: flash.text.seq2seq.summarization.utils

Translation
===========

Expand All @@ -74,8 +73,3 @@ Task
****

.. automodule:: flash.text.seq2seq.translation.model

Metric
******

.. automodule:: flash.text.seq2seq.translation.metric
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch import tensor
from torchmetrics import Metric

from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE
from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence
from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence

if _TEXT_AVAILABLE:
from rouge_score import rouge_scorer
Expand All @@ -27,6 +34,103 @@
AggregateScore, Score, BootstrapAggregator = None, None, object


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
"""
Counting how many times each word appears in a given text with ngram
Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4
Return:
ngram_counter: a collections.Counter object of ngram
"""

ngram_counter = Counter()

for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j:(i + j)])
ngram_counter[ngram_key] += 1

return ngram_counter


class BLEUScore(Metric):
    """
    Corpus-level BLEU score of machine translated text against one or more references.

    Example:
        >>> translate_corpus = ['the cat is on the mat'.split()]
        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
        >>> metric = BLEUScore()
        >>> metric(translate_corpus, reference_corpus)
        tensor(0.7598)
    """

    def __init__(self, n_gram: int = 4, smooth: bool = False):
        """
        Args:
            n_gram: Gram value ranged from 1 to 4 (Default 4)
            smooth: Whether or not to apply smoothing – Lin et al. 2004
        """
        super().__init__()
        self.n_gram = n_gram
        self.smooth = smooth

        # Running totals, all registered with a "sum" reduction:
        #   c            total length of the translations seen so far
        #   r            total length of the closest-length reference per translation
        #   numerator    clipped n-gram matches, one slot per n-gram order
        #   denominator  n-grams present in the translations, one slot per n-gram order
        self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
        self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")

    def compute(self):
        """
        Combine the accumulated statistics into the BLEU score.

        Return:
            A scalar tensor: brevity penalty times the geometric mean of the n-gram
            precisions, or 0 when any n-gram order has no clipped match.
        """
        device = self.r.device

        # A zero match count for any order gives a score of 0 (and avoids log(0) below).
        if min(self.numerator) == 0.0:
            return tensor(0.0, device=device)

        translation_length = self.c.clone().detach()
        reference_length = self.r.clone().detach()

        if self.smooth:
            precisions = (self.numerator + 1.0) / (self.denominator + 1.0)
        else:
            precisions = self.numerator / self.denominator

        # Uniform weights over the n-gram orders; exp(sum(w * log(p))) is the geometric mean.
        weights = tensor([1.0 / self.n_gram] * self.n_gram, device=device)
        geometric_mean = torch.exp(torch.sum(weights * torch.log(precisions)))

        # Only translations that are not longer than the references are penalised.
        if self.c > self.r:
            brevity_penalty = tensor(1.0, device=device)
        else:
            brevity_penalty = torch.exp(1 - (reference_length / translation_length))

        return brevity_penalty * geometric_mean

    def update(self, translate_corpus, reference_corpus) -> None:
        """
        Accumulate length and n-gram statistics for a batch of translations.

        Args:
            translate_corpus: An iterable of machine translated corpus
            reference_corpus: An iterable of iterables of reference corpus
        """
        for hypothesis, refs in zip(translate_corpus, reference_corpus):
            hypothesis_length = len(hypothesis)
            self.c += hypothesis_length

            # Use the reference whose length is closest to the translation's;
            # on a tie the earliest reference wins.
            reference_lengths = [len(ref) for ref in refs]
            self.r += min(reference_lengths, key=lambda length: abs(hypothesis_length - length))

            hypothesis_counts = _count_ngram(hypothesis, self.n_gram)

            # Union keeps, per n-gram, the maximum count over all references.
            reference_counts = Counter()
            for ref in refs:
                reference_counts |= _count_ngram(ref, self.n_gram)

            # Intersection clips each translation count to that reference maximum.
            clipped_counts = hypothesis_counts & reference_counts

            # An n-gram of length k is accumulated into slot k - 1.
            for ngram, count in clipped_counts.items():
                self.numerator[len(ngram) - 1] += count

            for ngram, count in hypothesis_counts.items():
                self.denominator[len(ngram) - 1] += count


class RougeMetric(Metric):
"""
Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/
Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import torch
from torchmetrics import Metric

from flash.text.seq2seq.core.metrics import RougeMetric
from flash.text.seq2seq.core.model import Seq2SeqTask
from flash.text.seq2seq.summarization.metric import RougeMetric


class SummarizationTask(Seq2SeqTask):
Expand Down
121 changes: 0 additions & 121 deletions flash/text/seq2seq/translation/metric.py

This file was deleted.

2 changes: 1 addition & 1 deletion flash/text/seq2seq/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import torch
from torchmetrics import Metric

from flash.text.seq2seq.core.metrics import BLEUScore
from flash.text.seq2seq.core.model import Seq2SeqTask
from flash.text.seq2seq.translation.metric import BLEUScore


class TranslationTask(Seq2SeqTask):
Expand Down
Empty file added tests/text/seq2seq/__init__.py
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@
import pytest
import torch

from flash.text.seq2seq.translation.metric import BLEUScore
from flash.text.seq2seq.core.metrics import BLEUScore, RougeMetric
from tests.helpers.utils import _TEXT_TESTING


@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_rouge():
    """Check the ``rouge1_recall`` value RougeMetric reports for one prediction/target pair."""
    predictions = "My name is John".split()
    references = "Is your name John".split()

    scores = RougeMetric()(predictions, references)
    observed_recall = torch.tensor(scores["rouge1_recall"]).float()

    assert torch.allclose(observed_recall, torch.tensor(0.25), 1e-4)


@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)])
Expand Down
26 changes: 0 additions & 26 deletions tests/text/seq2seq/summarization/test_metric.py

This file was deleted.

0 comments on commit b70f940

Please sign in to comment.