add Rouge score #399

Merged
Changes from all commits
26 commits
3a23b94
Adding new metric ROUGE Metric for text
karthikrangasai Jul 23, 2021
a398b6f
Added tests for the ROUGE metric
karthikrangasai Jul 23, 2021
bca2f12
Updated docs and imports, added types
karthikrangasai Jul 23, 2021
586e82a
Apply suggestions from code review
Borda Jul 23, 2021
5169561
Applied changes suggested in code review
karthikrangasai Jul 23, 2021
2dfcd9f
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 25, 2021
7e9fed1
Updated text dependencies and CHANGELOG
karthikrangasai Jul 25, 2021
8027712
Fix typing issues
karthikrangasai Jul 25, 2021
067c7f0
Updated docs dependencies
karthikrangasai Jul 25, 2021
cba098a
pkg
Borda Jul 26, 2021
506cfc7
pkg
Borda Jul 26, 2021
b0f28bd
set jiwer
Borda Jul 26, 2021
5b5872b
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 26, 2021
d5782fe
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 27, 2021
49f40dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2021
a0ef6e7
Simplified the implementation for batches and added more tests.
karthikrangasai Jul 27, 2021
e9ecc62
Updated docs requirements and removed unused imports.
karthikrangasai Jul 27, 2021
8fc553a
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 28, 2021
8d2e102
Merge branch 'master' into feature/51_add_rouge_score
Borda Jul 28, 2021
36e3a7d
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 28, 2021
af2c103
Fix typing, rigorously check rouge_keys, add tests for rouge_keys err…
karthikrangasai Jul 28, 2021
34f8394
Remove unused imports
karthikrangasai Jul 28, 2021
5fd1160
Apply suggestions from code review
Borda Jul 28, 2021
1439497
Fixed typing and added docstrings for update and compute method
karthikrangasai Jul 29, 2021
71d17f4
Merge branch 'master' into feature/51_add_rouge_score
karthikrangasai Jul 29, 2021
218418f
Changes based on review
karthikrangasai Jul 29, 2021
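
For review context, here is a minimal usage sketch of the two interfaces this PR adds, pieced together from the tests and docs changes in the diff below; the defaults and exact output keys shown are assumptions based on that diff, not on a released API.

from torchmetrics import ROUGEScore
from torchmetrics.functional import rouge_score

preds = ["The quick brown fox jumps over the lazy dog"]
targets = ["The quick brown dog jumps on the log."]

# Functional interface: one-shot computation; returns a dict of tensors
# keyed like "rouge1_precision", "rougeL_fmeasure", etc. (assumed default keys).
scores = rouge_score(preds, targets)
print(scores["rouge1_fmeasure"])

# Module interface: accumulates state across update() calls, then compute().
rouge = ROUGEScore()
rouge.update(preds, targets)
print(rouge.compute()["rouge2_recall"])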
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Symmetric Mean Absolute Percentage error (SMAPE) ([#375](https://github.com/PyTorchLightning/metrics/issues/375))


- Added ROUGE Metric ([#399](https://github.com/PyTorchLightning/metrics/issues/399))


- Allowed passing labels in (n_samples, n_classes) to `AveragePrecision` ([#386](https://github.com/PyTorchLightning/metrics/issues/386))


4 changes: 4 additions & 0 deletions docs/source/references/functional.rst
@@ -355,6 +355,10 @@ bleu_score [func]
.. autofunction:: torchmetrics.functional.bleu_score
:noindex:

rouge_score [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.rouge_score

wer [func]
~~~~~~~~~~
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -517,6 +517,12 @@ BLEUScore
.. autoclass:: torchmetrics.BLEUScore
:noindex:

ROUGEScore
~~~~~~~~~~

.. autoclass:: torchmetrics.ROUGEScore
:noindex:


WER
~~~
1 change: 0 additions & 1 deletion requirements/test.txt
@@ -15,7 +15,6 @@ phmdoctest>=1.1.1
cloudpickle>=1.3
scikit-learn>=0.24
scikit-image>0.17.1
nltk>=3.6

# add extra requirements
-r image.txt
4 changes: 3 additions & 1 deletion requirements/text.txt
@@ -1 +1,3 @@
jiwer==2.2.0
jiwer>=2.2.0
nltk>=3.6
rouge-score>=0.0.4
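
The metric needs both nltk and rouge-score at run time, which is why they move into the text requirements here. A small sketch of an availability guard, assuming the flags mirror the _NLTK_AVAILABLE / _ROUGE_SCORE_AVAILABLE values that the tests below import from torchmetrics.utilities.imports (how the real flags are computed is an assumption, and the error wording is copied from the test at the end of the new file):

from importlib.util import find_spec

# Assumed equivalents of the availability flags used by the tests below.
_NLTK_AVAILABLE = find_spec("nltk") is not None
_ROUGE_SCORE_AVAILABLE = find_spec("rouge_score") is not None

if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
    raise ValueError(
        'ROUGE metric requires that both nltk and rouge-score is installed.'
        ' Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`'
    )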
205 changes: 205 additions & 0 deletions tests/text/test_rouge.py
@@ -0,0 +1,205 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import pytest
import torch
from torch import tensor

from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE

if _ROUGE_SCORE_AVAILABLE:
    from rouge_score.rouge_scorer import RougeScorer
    from rouge_score.scoring import BootstrapAggregator
else:
    RougeScorer, BootstrapAggregator = object, object

ROUGE_KEYS = ("rouge1", "rouge2", "rougeL", "rougeLsum")

PRECISION = 0
RECALL = 1
F_MEASURE = 2

SINGLE_SENTENCE_EXAMPLE_PREDS = 'The quick brown fox jumps over the lazy dog'
SINGLE_SENTENCE_EXAMPLE_TARGET = 'The quick brown dog jumps on the log.'

PREDS = "My name is John".split()
TARGETS = "Is your name John".split()

BATCHES_RS_PREDS = [SINGLE_SENTENCE_EXAMPLE_PREDS]
BATCHES_RS_PREDS.extend(PREDS)
BATCHES_RS_TARGETS = [SINGLE_SENTENCE_EXAMPLE_TARGET]
BATCHES_RS_TARGETS.extend(TARGETS)

BATCHES = [
    dict(preds=[SINGLE_SENTENCE_EXAMPLE_PREDS], targets=[SINGLE_SENTENCE_EXAMPLE_TARGET]),
    dict(preds=PREDS, targets=TARGETS)
]


def _compute_rouge_score(preds: List[str], targets: List[str], use_stemmer: bool):
    scorer = RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = BootstrapAggregator()
    for pred, target in zip(preds, targets):
        aggregator.add_scores(scorer.score(pred, target))
    return aggregator.aggregate()


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason='test requires nltk and rouge-score')
@pytest.mark.parametrize(
    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
    [
        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
    ],
)
def test_rouge_metric_functional_single_sentence(
    pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
    scorer = RougeScorer(ROUGE_KEYS)
    rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
    rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)

    pl_output = rouge_score([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET],
                            newline_sep=newline_sep,
                            use_stemmer=use_stemmer,
                            decimal_places=decimal_places)

    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason='test requires nltk and rouge-score')
@pytest.mark.parametrize(
    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
    [
        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
    ],
)
def test_rouge_metric_functional(
    pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
    rs_scores = _compute_rouge_score(PREDS, TARGETS, use_stemmer=use_stemmer)
    rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)

    pl_output = rouge_score(
        PREDS, TARGETS, newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places
    )

    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason='test requires nltk and rouge-score')
@pytest.mark.parametrize(
    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
    [
        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
    ],
)
def test_rouge_metric_class(pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep):
    scorer = RougeScorer(ROUGE_KEYS)
    rs_scores = scorer.score(SINGLE_SENTENCE_EXAMPLE_PREDS, SINGLE_SENTENCE_EXAMPLE_TARGET)
    rs_output = round(rs_scores[rouge_score_key][metric], decimal_places)

    rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
    pl_output = rouge([SINGLE_SENTENCE_EXAMPLE_PREDS], [SINGLE_SENTENCE_EXAMPLE_TARGET])

    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))


@pytest.mark.skipif(not (_NLTK_AVAILABLE or _ROUGE_SCORE_AVAILABLE), reason='test requires nltk and rouge-score')
@pytest.mark.parametrize(
    ["pl_rouge_metric_key", "rouge_score_key", "metric", "decimal_places", "use_stemmer", "newline_sep"],
    [
        pytest.param("rouge1_precision", "rouge1", PRECISION, 1, True, True),
        pytest.param("rouge1_recall", "rouge1", RECALL, 2, True, False),
        pytest.param("rouge1_fmeasure", "rouge1", F_MEASURE, 3, False, True),
        pytest.param("rouge2_precision", "rouge2", PRECISION, 4, False, False),
        pytest.param("rouge2_recall", "rouge2", RECALL, 5, True, True),
        pytest.param("rouge2_fmeasure", "rouge2", F_MEASURE, 6, True, False),
        pytest.param("rougeL_precision", "rougeL", PRECISION, 6, False, True),
        pytest.param("rougeL_recall", "rougeL", RECALL, 5, False, False),
        pytest.param("rougeL_fmeasure", "rougeL", F_MEASURE, 3, True, True),
        pytest.param("rougeLsum_precision", "rougeLsum", PRECISION, 2, True, False),
        pytest.param("rougeLsum_recall", "rougeLsum", RECALL, 1, False, True),
        pytest.param("rougeLsum_fmeasure", "rougeLsum", F_MEASURE, 8, False, False),
    ],
)
def test_rouge_metric_class_batches(
    pl_rouge_metric_key, rouge_score_key, metric, decimal_places, use_stemmer, newline_sep
):
    rs_scores = _compute_rouge_score(BATCHES_RS_PREDS, BATCHES_RS_TARGETS, use_stemmer=use_stemmer)
    rs_output = round(rs_scores[rouge_score_key].mid[metric], decimal_places)

    rouge = ROUGEScore(newline_sep=newline_sep, use_stemmer=use_stemmer, decimal_places=decimal_places)
    for batch in BATCHES:
        rouge.update(batch['preds'], batch['targets'])
    pl_output = rouge.compute()

    assert torch.allclose(pl_output[pl_rouge_metric_key], tensor(rs_output, dtype=torch.float32))


def test_rouge_metric_raises_errors_and_warnings():
    """ Test that expected warnings and errors are raised """
    if not (_NLTK_AVAILABLE and _ROUGE_SCORE_AVAILABLE):
        with pytest.raises(
            ValueError,
            match='ROUGE metric requires that both nltk and rouge-score is installed.'
            'Either as `pip install torchmetrics[text]` or `pip install nltk rouge-score`'
        ):
            ROUGEScore()


def test_rouge_metric_wrong_key_value_error():
    key = ("rouge1", "rouge")

    with pytest.raises(ValueError):
        ROUGEScore(rouge_keys=key)

    with pytest.raises(ValueError):
        rouge_score(PREDS, TARGETS, rouge_keys=key)
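
The last test implies that rouge_keys must come from the supported set ("rouge1", "rouge2", "rougeL", "rougeLsum") and that invalid entries raise a ValueError. A hypothetical illustration of that validation (passing a reduced key set is an assumption based on the test above):

from torchmetrics.functional import rouge_score

preds = "My name is John".split()
targets = "Is your name John".split()

# Assumed to be valid: restrict scoring to a subset of the supported ROUGE variants.
scores = rouge_score(preds, targets, rouge_keys=("rouge1", "rougeL"))
print(scores["rouge1_fmeasure"], scores["rougeL_fmeasure"])

# Expected to raise ValueError: "rouge" is not a recognised key.
# rouge_score(preds, targets, rouge_keys=("rouge1", "rouge"))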
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -60,5 +60,5 @@
    RetrievalPrecision,
    RetrievalRecall,
)
from torchmetrics.text import WER, BLEUScore # noqa: E402, F401
from torchmetrics.text import WER, BLEUScore, ROUGEScore # noqa: E402, F401
from torchmetrics.wrappers import BootStrapper # noqa: E402, F401
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -58,4 +58,5 @@
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
from torchmetrics.functional.text.rouge import rouge_score # noqa: F401
from torchmetrics.functional.text.wer import wer # noqa: F401