Added Blue Score the respective folders #360

Merged: 24 commits merged into master from feature/352_add_blue_score on Jul 19, 2021.

Commits (changes shown from 18 of the 24 commits):
77645b9
Added Blue Score the respective folders
karthikrangasai Jul 8, 2021
1b9bcf1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2021
be4ccbd
File naming correction and moved existing tests
karthikrangasai Jul 9, 2021
70cf0bb
Fixes from pre-commit
karthikrangasai Jul 11, 2021
672ea95
Updated function definitions to be in sync with nltk style
karthikrangasai Jul 11, 2021
b60e474
Added Blue Score the respective folders
karthikrangasai Jul 8, 2021
c03e97c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2021
dad9026
File naming correction and moved existing tests
karthikrangasai Jul 9, 2021
ee6166a
Fixes from pre-commit
karthikrangasai Jul 11, 2021
60afb0b
Updated function definitions to be in sync with nltk style
karthikrangasai Jul 11, 2021
c15958d
Merge branch 'PyTorchLightning:master' into feature/352_add_blue_score
karthikrangasai Jul 11, 2021
7d4389c
Merge branch 'feature/352_add_blue_score' of https://github.com/karth…
karthikrangasai Jul 11, 2021
39ef530
Updated docs references in the rst files to reflect in the HTML.
karthikrangasai Jul 12, 2021
a38a409
Made naming changes for consistency, updated references in docs, adde…
karthikrangasai Jul 12, 2021
5b4f938
Added functional/nlp.py back with Deprecation Warning for current sup…
karthikrangasai Jul 12, 2021
853ffd4
Fixed import error
karthikrangasai Jul 13, 2021
819376f
Updated docstring for deprecation and added tests for metric computat…
karthikrangasai Jul 15, 2021
6d5e599
Merge branch 'master' into feature/352_add_blue_score
SkafteNicki Jul 15, 2021
9273fb8
deprecate
Borda Jul 15, 2021
89e61c2
chlog
Borda Jul 15, 2021
b5e5e41
types
Borda Jul 15, 2021
8c15a6b
Fixing doctests, updating test variables types
karthikrangasai Jul 16, 2021
7c466bb
Merge branch 'master' into feature/352_add_blue_score
SkafteNicki Jul 17, 2021
6284f6a
type
Borda Jul 19, 2021
21 changes: 10 additions & 11 deletions docs/source/references/functional.rst
@@ -281,17 +281,6 @@ ssim [func]
.. autofunction:: torchmetrics.functional.ssim
:noindex:


***
NLP
Member: shall we call it NLP or Text?

Contributor Author: I am not sure about this. NLP also includes speech processing, so if we are to add those metrics as well, then we can call it NLP.

Member: Well, by speech processing you mean conversion from audio > text and back, right? But then you would still measure the quality of each independently, as your prediction and target are always either audio or text, right? So we could split NLP into text and audio... 🐰
cc: @SkafteNicki @maximsch2

Member: I prefer Text, as the data modality that BLEU works on is text, similar to how we have grouped other metrics by data modality. Also, this does not matter for end users, since all modular metrics can be imported with `from torchmetrics import *` and functional ones with `from torchmetrics.functional import *`.

***

bleu_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bleu_score
:noindex:

********
Pairwise
********
@@ -346,3 +335,13 @@ retrieval_normalized_dcg [func]

.. autofunction:: torchmetrics.functional.retrieval_normalized_dcg
:noindex:

****
Text
****

bleu_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.bleu_score
:noindex:
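
As the review thread above concludes, the regrouping under Text is invisible to end users, since bleu_score stays re-exported at the top of torchmetrics.functional. A minimal usage sketch of the relocated functional API; the corpus values come from the doctest later in this diff, and the (reference_corpus, translate_corpus) argument order follows the updated tests, so treat it as illustrative:

from torchmetrics.functional import bleu_score  # now backed by torchmetrics/functional/text/bleu.py

translate_corpus = ["the cat is on the mat".split()]
reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]
score = bleu_score(reference_corpus, translate_corpus)  # references first, per this PR's signature
print(score)  # tensor(0.7598), the value shown in the doctest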
10 changes: 10 additions & 0 deletions docs/source/references/modules.rst
@@ -502,6 +502,16 @@ RetrievalNormalizedDCG
.. autoclass:: torchmetrics.RetrievalNormalizedDCG
:noindex:

****
Text
****

BLEUScore
~~~~~~~~~

.. autoclass:: torchmetrics.BLEUScore
:noindex:


********
Wrappers
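For comparison, a minimal sketch of the modular BLEUScore that the new docs entry points at, mirroring the class-based tests added below; the n_gram=4, smooth=False defaults are assumed from the functional signature:

from torchmetrics import BLEUScore

bleu = BLEUScore(n_gram=4, smooth=False)
reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]
translate_corpus = ["the cat is on the mat".split()]
print(bleu(reference_corpus, translate_corpus))  # forward call: updates internal state and returns the score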
82 changes: 74 additions & 8 deletions tests/functional/test_nlp.py → tests/text/test_blue.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
from torch import tensor

from torchmetrics.functional import bleu_score
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.text.bleu import BLEUScore

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
@@ -42,6 +44,11 @@
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]

BATCHES = [
dict(reference_corpus=[[REF1A, REF1B, REF1C]], translate_corpus=[HYP1]),
dict(reference_corpus=[[REF2A]], translate_corpus=[HYP2])
]

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2

@@ -55,28 +62,87 @@
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
def test_bleu_score_functional(weights, n_gram, smooth_func, smooth):
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3],
HYPOTHESIS1,
weights=weights,
smoothing_function=smooth_func,
)
pl_output = bleu_score([[REFERENCE1, REFERENCE2, REFERENCE3]], [HYPOTHESIS1], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, tensor(nltk_output))

nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(LIST_OF_REFERENCES, HYPOTHESES, n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, tensor(nltk_output))


def test_bleu_empty_functional():
hyp = [[]]
ref = [[[]]]
assert bleu_score(ref, hyp) == tensor(0.0)


def test_no_4_gram_functional():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(refs, hyps) == tensor(0.0)


@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score_class(weights, n_gram, smooth_func, smooth):
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3],
HYPOTHESIS1,
weights=weights,
smoothing_function=smooth_func,
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
pl_output = bleu([[REFERENCE1, REFERENCE2, REFERENCE3]], [HYPOTHESIS1])
assert torch.allclose(pl_output, tensor(nltk_output))

nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
pl_output = bleu(LIST_OF_REFERENCES, HYPOTHESES)
assert torch.allclose(pl_output, tensor(nltk_output))


@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score_class_batches(weights, n_gram, smooth_func, smooth):
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)

nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)

for batch in BATCHES:
bleu.update(batch['reference_corpus'], batch['translate_corpus'])
pl_output = bleu.compute()
assert torch.allclose(pl_output, tensor(nltk_output))


def test_bleu_empty():
def test_bleu_empty_class():
bleu = BLEUScore()
hyp = [[]]
ref = [[[]]]
assert bleu_score(hyp, ref) == tensor(0.0)
assert bleu(ref, hyp) == tensor(0.0)


def test_no_4_gram():
def test_no_4_gram_class():
bleu = BLEUScore()
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(hyps, refs) == tensor(0.0)
assert bleu(refs, hyps) == tensor(0.0)
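
The new test_bleu_score_class_batches test above exercises the metric's stateful accumulation across batches; a condensed sketch of that update/compute pattern, reusing the BATCHES list the test file defines:

from torchmetrics.text.bleu import BLEUScore

bleu = BLEUScore(n_gram=4)
for batch in BATCHES:  # each dict pairs a reference_corpus with its translate_corpus
    bleu.update(batch["reference_corpus"], batch["translate_corpus"])
final_score = bleu.compute()  # aggregated over all batches; the test asserts parity with nltk's corpus_bleu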
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -61,4 +61,5 @@
RetrievalPrecision,
RetrievalRecall,
)
from torchmetrics.text import BLEUScore # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
2 changes: 1 addition & 1 deletion torchmetrics/functional/__init__.py
@@ -33,7 +33,6 @@
from torchmetrics.functional.classification.specificity import specificity # noqa: F401
from torchmetrics.functional.classification.stat_scores import stat_scores # noqa: F401
from torchmetrics.functional.image_gradients import image_gradients # noqa: F401
from torchmetrics.functional.nlp import bleu_score # noqa: F401
from torchmetrics.functional.regression.cosine_similarity import cosine_similarity # noqa: F401
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
@@ -55,3 +54,4 @@
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
from torchmetrics.functional.self_supervised import embedding_similarity # noqa: F401
from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
103 changes: 32 additions & 71 deletions torchmetrics/functional/nlp.py
@@ -11,54 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Sequence
from warnings import warn

import torch
from torch import Tensor, tensor


def _count_ngram(ngram_input_list: Sequence[str], n_gram: int) -> Counter:
"""
Counting how many times each word appears in a given text with ngram

Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4

Return:
ngram_counter: a collections.Counter object of ngram
"""

ngram_counter: Counter = Counter()

for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j:(i + j)])
ngram_counter[ngram_key] += 1

return ngram_counter
from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update


def bleu_score(
translate_corpus: Sequence[Sequence[str]],
reference_corpus: Sequence[Sequence[Sequence[str]]],
translate_corpus: Sequence[Sequence[str]],
n_gram: int = 4,
smooth: bool = False
) -> Tensor:
"""
Calculate BLEU score of machine translated text with one or more references
Calculate `BLEU score <https://en.wikipedia.org/wiki/BLEU>`_ of machine translated text with one or more references

Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4
smooth: Whether or not to apply smoothing – Lin et al. 2004
reference_corpus:
An iterable of iterables of reference corpus
translate_corpus:
An iterable of machine translated corpus
n_gram:
Gram value ranged from 1 to 4 (Default 4)
smooth:
Whether or not to apply smoothing – see [2]

Return:
Tensor with BLEU Score
@@ -69,49 +48,31 @@ def bleu_score(
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> bleu_score(translate_corpus, reference_corpus)
tensor(0.7598)

References:
[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu http://www.aclweb.org/anthology/P02-1040.pdf

[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och https://aclanthology.org/P04-1077.pdf

.. deprecated:: v0.5
Use :func:`torchmetrics.functional.text.bleu.bleu_score`. Will be removed in v0.6.
"""
warn(
"Function `functional.nlp.bleu_score` will be deprecated in v0.5 and will be removed in v0.6."
"Use `functional.text.bleu.bleu_score` instead.", DeprecationWarning
)

if len(translate_corpus) != len(reference_corpus):
raise ValueError(f"Corpus has different size {len(translate_corpus)} != {len(reference_corpus)}")
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
c = 0.0
r = 0.0

for (translation, references) in zip(translate_corpus, reference_corpus):
c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter: Counter = _count_ngram(translation, n_gram)
reference_counter: Counter = Counter()

for ref in references:
reference_counter |= _count_ngram(ref, n_gram)

ngram_counter_clip = translation_counter & reference_counter

for counter_clip in ngram_counter_clip:
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]

for counter in translation_counter:
denominator[len(counter) - 1] += translation_counter[counter]

trans_len = tensor(c)
ref_len = tensor(r)

if min(numerator) == 0.0:
return tensor(0.0)

if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
precision_scores[0] = numerator[0] / denominator[0]
else:
precision_scores = numerator / denominator
trans_len = tensor(0, dtype=torch.float)
ref_len = tensor(0, dtype=torch.float)

log_precision_scores = tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean
trans_len, ref_len = _bleu_score_update(
reference_corpus, translate_corpus, numerator, denominator, trans_len, ref_len, n_gram
)

return bleu
return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth)
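
The old torchmetrics.functional.nlp path is kept as a thin shim: it emits a DeprecationWarning and delegates to the new text implementation. A sketch of what a caller observes, with the warning text copied from the diff above:

import warnings
from torchmetrics.functional.nlp import bleu_score  # deprecated location, still importable until v0.6

translate_corpus = ["the cat is on the mat".split()]
reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    score = bleu_score(reference_corpus, translate_corpus)
# caught[0] holds the DeprecationWarning; the shim then forwards to
# _bleu_score_update and _bleu_score_compute from functional/text/bleu.py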
15 changes: 15 additions & 0 deletions torchmetrics/functional/text/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torchmetrics.functional.text.bleu import bleu_score # noqa: F401
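
With this re-export in place, the function resolves at three import depths; all three lines below should be equivalent after this PR:

from torchmetrics.functional import bleu_score            # top-level functional namespace
from torchmetrics.functional.text import bleu_score       # via this new __init__.py
from torchmetrics.functional.text.bleu import bleu_score  # the concrete module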