add Extended Edit Distance (EED) metric #668

Merged on Jan 10, 2022 (85 commits)

Commits
8d0b667
add Extended Edit Distance (EED) metric
mathemusician Dec 7, 2021
0febf60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2021
ac641b5
flake8, mypy, and doctest
mathemusician Dec 7, 2021
8c23dd2
flake8, mypy, doctest, Merge branch 'master' of https://github.com/ma…
mathemusician Dec 7, 2021
1163bf5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2021
aa841e8
doctest
mathemusician Dec 7, 2021
21e178a
flake8 long lines
mathemusician Dec 7, 2021
5a31d8e
flake8 trailing whitespace
mathemusician Dec 7, 2021
4ccb994
alphabetical order
mathemusician Dec 8, 2021
8210ad9
camelCase to lower_case, type-checking, and other style changes
mathemusician Dec 8, 2021
b3afa73
change errors and add np.isclose
mathemusician Dec 8, 2021
aa1a60c
fix imports
mathemusician Dec 8, 2021
d6940ea
fixed weird bug where parallelized metric was giving different answer…
mathemusician Dec 8, 2021
c071e98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
5b32d16
flake8
mathemusician Dec 8, 2021
6329506
Merge branch 'master' of https://github.com/mathemusician/metrics
mathemusician Dec 8, 2021
c6d2cf7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
9c957d7
update CHANGELOG.md
mathemusician Dec 8, 2021
c882f2e
Merge branch 'master' of https://github.com/mathemusician/metrics
mathemusician Dec 8, 2021
2ca9884
import Literal keyword from typing_extensions
mathemusician Dec 8, 2021
3269e32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
92a61c3
fix errors from doctest
mathemusician Dec 8, 2021
a04ea27
Merge branch 'master' of https://github.com/mathemusician/metrics
mathemusician Dec 8, 2021
c4087e4
fix bug on doctest example
mathemusician Dec 8, 2021
5b8594d
favor np.isclose to strict value comparison because of how different …
mathemusician Dec 8, 2021
9034430
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2021
a322421
convert np.bool_ type to bool for assertion check
mathemusician Dec 8, 2021
c342d49
Merge branch 'master' of https://github.com/mathemusician/metrics
mathemusician Dec 8, 2021
caef80f
Merge branch 'PyTorchLightning:master' into master
mathemusician Dec 9, 2021
f0f2512
reorder hypotheses and references, rename to hypothesis_corpus and re…
mathemusician Dec 10, 2021
d55c871
reorder references and hypotheses
mathemusician Dec 10, 2021
09c074d
add control for tunable parameters, documentation, better documentation
mathemusician Dec 10, 2021
62fb435
added sentence level scoring
mathemusician Dec 10, 2021
e9be5ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2021
3beda35
more documentation, fixed unused variable, cleanup documentation
mathemusician Dec 10, 2021
d314c8f
Merge branch 'master' of https://github.com/mathemusician/metrics
mathemusician Dec 10, 2021
70acd1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2021
2eb7f95
Apply suggestions from code review
Borda Dec 12, 2021
2f83b67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2021
0e3f686
Merge branch 'master' into master
Borda Dec 12, 2021
f36bb11
stylistic changes, add ability to return sentence level scores
mathemusician Dec 17, 2021
aea6e63
added TextTester, other stylistic code changes
mathemusician Dec 17, 2021
1eef7e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2021
827b700
Merge branch 'master' into master
Borda Dec 17, 2021
938ceee
Update paper.md (#690)
maximsch2 Dec 21, 2021
67f78ab
ci: rename oldest
Borda Dec 21, 2021
c2a1277
add tuple_of_single_references
mathemusician Dec 22, 2021
d739287
fix differentiability; stylistic changes
mathemusician Dec 22, 2021
c917dab
update test to use TextTester correctly
mathemusician Dec 22, 2021
07eefaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2021
6a17dd3
Merge branch 'master' into master
Borda Dec 22, 2021
a34d8c1
Merge branch 'master' into master
Borda Dec 22, 2021
4ccdd5a
Merge branch 'master' into master
Borda Dec 27, 2021
9e5a195
only use necessary tests
mathemusician Dec 29, 2021
e0d6888
add space to beginning and end of sentence
mathemusician Dec 29, 2021
5e12fc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 29, 2021
3f3d4f6
add more decimal places
mathemusician Dec 30, 2021
21cdb3c
mypy
mathemusician Dec 30, 2021
7b3c44d
Merge branch 'master' of https://github.com/mathemusician/metrics
mathemusician Dec 30, 2021
3d1010c
fix mypy issues
mathemusician Dec 30, 2021
baa4560
fix class_metric_test
mathemusician Dec 30, 2021
ca9fd18
update documentation
mathemusician Dec 30, 2021
8f4f50d
Rename collections.py to metrics_collections.py (#695)
tkupek Dec 28, 2021
d7e2542
add JOSS status badge
Borda Dec 29, 2021
5b2c0c7
uncomment tests
mathemusician Jan 4, 2022
d978d2b
add comments on eed test
mathemusician Jan 4, 2022
a18df4a
switch to using preds, target
mathemusician Jan 4, 2022
c31ad70
[pre-commit.ci] pre-commit suggestions (#705)
pre-commit-ci[bot] Jan 3, 2022
83976a8
use length comparison instead of object comparison
mathemusician Jan 4, 2022
f716e5f
2022
Borda Jan 5, 2022
5953668
rename eed and EED to extended_edit_distance and ExtendedEditDistance…
mathemusician Jan 5, 2022
c6932b5
add comma, flake8
mathemusician Jan 5, 2022
67203f1
use extended_edit_distance at functional/__init__.py
mathemusician Jan 5, 2022
2617b7a
flake8
mathemusician Jan 5, 2022
383f985
update functional.rst
mathemusician Jan 5, 2022
667e610
add tilde to fix docs build
mathemusician Jan 6, 2022
d1e31bf
Apply suggestions from code review
SkafteNicki Jan 7, 2022
1e44e59
Update docs/source/references/modules.rst
SkafteNicki Jan 8, 2022
de5adf1
Update torchmetrics/functional/text/eed.py
SkafteNicki Jan 8, 2022
b88572b
Update tests/text/test_eed.py
SkafteNicki Jan 8, 2022
e0b25c3
fix flake
SkafteNicki Jan 8, 2022
8be9464
Merge branch 'master' into master
Borda Jan 10, 2022
3c3395c
Apply suggestions from code review
Borda Jan 10, 2022
cedc44f
Merge branch 'master' into master
mergify[bot] Jan 10, 2022
be3814e
Merge branch 'master' into master
Borda Jan 10, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `SQuAD` ([#623](https://github.com/PyTorchLightning/metrics/pull/623))
- `CHRFScore` ([#641](https://github.com/PyTorchLightning/metrics/pull/641))
- `TranslationEditRate` ([#646](https://github.com/PyTorchLightning/metrics/pull/646))
- `ExtendedEditDistance` ([#668](https://github.com/PyTorchLightning/metrics/pull/668))


- Added `MultiScaleSSIM` into image metrics ([#679](https://github.com/PyTorchLightning/metrics/pull/679))

1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -76,4 +76,5 @@
.. _chrF score: https://aclanthology.org/W15-3049.pdf
.. _chrF++ score: https://aclanthology.org/W17-4770.pdf
.. _TER: https://aclanthology.org/2006.amta-papers.25.pdf
.. _ExtendedEditDistance: https://aclanthology.org/W19-5359.pdf
.. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
6 changes: 6 additions & 0 deletions docs/source/references/functional.rst
@@ -460,6 +460,12 @@ chrf_score [func]
.. autofunction:: torchmetrics.functional.chrf_score
    :noindex:

extended_edit_distance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.extended_edit_distance
    :noindex:

match_error_rate [func]
~~~~~~~~~~~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -642,6 +642,12 @@ CHRFScore
.. autoclass:: torchmetrics.CHRFScore
    :noindex:

ExtendedEditDistance
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.ExtendedEditDistance
    :noindex:

MatchErrorRate
~~~~~~~~~~~~~~

4 changes: 4 additions & 0 deletions tests/text/inputs.py
@@ -63,3 +63,7 @@
_inputs_error_rate_batch_size_1 = Input(**ERROR_RATES_BATCHES_1)

_inputs_error_rate_batch_size_2 = Input(**ERROR_RATES_BATCHES_2)

# single reference
TUPLE_OF_SINGLE_REFERENCES = (((REFERENCE_1A), (REFERENCE_1B)), ((REFERENCE_1B), (REFERENCE_1C)))
_inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES)
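
A note on the shape of `TUPLE_OF_SINGLE_REFERENCES`: parentheses around a single name do not create a tuple in Python, so `(REFERENCE_1A)` is just `REFERENCE_1A`. The sketch below illustrates this with placeholder strings; it assumes the `REFERENCE_*` constants are plain strings, and the placeholder values are hypothetical, not the real ones from `tests/text/inputs.py`.

```python
# Placeholder values (assumption): the real REFERENCE_* constants are defined
# in tests/text/inputs.py and are assumed here to be plain strings.
REFERENCE_1A = "reference a"
REFERENCE_1B = "reference b"
REFERENCE_1C = "reference c"

# Parentheses around a single name only group; they do not build a tuple.
assert (REFERENCE_1A) is REFERENCE_1A
assert isinstance((REFERENCE_1A), str)

# So the definition from the diff is two batches of two reference strings each.
TUPLE_OF_SINGLE_REFERENCES = (((REFERENCE_1A), (REFERENCE_1B)), ((REFERENCE_1B), (REFERENCE_1C)))
equivalent = ((REFERENCE_1A, REFERENCE_1B), (REFERENCE_1B, REFERENCE_1C))
assert TUPLE_OF_SINGLE_REFERENCES == equivalent

# A trailing comma is what would make a one-element tuple instead.
one_element = (REFERENCE_1A,)
assert isinstance(one_element, tuple)
```

Under that assumption each hypothesis is paired with exactly one reference string rather than a one-element tuple of references.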
120 changes: 120 additions & 0 deletions tests/text/test_eed.py
@@ -0,0 +1,120 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import pytest
from torch import Tensor, tensor

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references
from torchmetrics.functional.text.eed import extended_edit_distance
from torchmetrics.text.eed import ExtendedEditDistance


def rwth_manual_metric(preds, targets) -> Tensor:
    """Return reference EED scores for the examples defined in `tests.text.inputs`.

    The values were obtained with the script from https://github.com/rwth-i6/ExtendedEditDistance.
    """
    ans_1 = tensor(0.24248056001808083)
    ans_2 = tensor(0.19152276295133436)

    HYPOTHESIS_A = "It is a guide to action which ensures that the military always obeys the commands of the party"

    # If hypotheses A and B are both in preds, return the average of ans_1 and ans_2
    if len(preds) == 4:
        return (ans_1 + ans_2) / 2
    # If only hypothesis A or only hypothesis B is given, return ans_1 or ans_2, respectively
    if HYPOTHESIS_A in preds:
        return ans_1
    return ans_2
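
The `(ans_1 + ans_2) / 2` branch suggests that the combined score is treated as a plain mean of sentence-level scores. A minimal sketch of that reading, using the two constants from the function above; `mean_score` is a hypothetical helper, not part of torchmetrics, and the tolerance-based comparison reflects the preference for `isclose`-style checks mentioned in the commit history.

```python
import math

# Sentence-level reference values copied from rwth_manual_metric.
ANS_1 = 0.24248056001808083
ANS_2 = 0.19152276295133436


def mean_score(scores):
    """Plain arithmetic mean of sentence-level scores (hypothetical helper)."""
    return sum(scores) / len(scores)


combined = mean_score([ANS_1, ANS_2])

# The same two scores, each contributed twice and in a different order, give
# the same mean up to floating-point rounding, so compare with a tolerance.
shuffled = mean_score([ANS_2, ANS_1, ANS_2, ANS_1])
assert math.isclose(combined, shuffled, rel_tol=1e-12)
```

Comparing with a tolerance rather than `==` keeps the check stable when the accumulation order differs, for example across processes.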


@pytest.mark.parametrize(
    ["preds", "targets"],
    [(_inputs_single_reference.preds, _inputs_single_reference.targets)],
)
class TestExtendedEditDistance(TextTester):
    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False, True])
    def test_eed_class(self, preds, targets, ddp, dist_sync_on_step):
        rwth_metric = partial(rwth_manual_metric)
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            targets=targets,
            metric_class=ExtendedEditDistance,
            sk_metric=rwth_metric,
            dist_sync_on_step=dist_sync_on_step,
        )

    def test_eed_functional(self, preds, targets):
        rwth_metric = partial(rwth_manual_metric)
        self.run_functional_metric_test(
            preds,
            targets,
            metric_functional=extended_edit_distance,
            sk_metric=rwth_metric,
        )

    def test_eed_differentiability(self, preds, targets):
        self.run_differentiability_test(
            preds=preds,
            targets=targets,
            metric_module=ExtendedEditDistance,
            metric_functional=extended_edit_distance,
        )
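
Stacked `pytest.mark.parametrize` decorators produce every combination of their values, so `test_eed_class` runs once per `(ddp, dist_sync_on_step)` pair. The sketch below enumerates those pairs with `itertools.product` purely for illustration; it does not use pytest and makes no claim about the order in which pytest generates the cases.

```python
from itertools import product

# Values taken from the two parametrize decorators on test_eed_class.
ddp_values = [False, True]
dist_sync_on_step_values = [False, True]

# Every (ddp, dist_sync_on_step) pair that the stacked decorators cover.
combinations = list(product(ddp_values, dist_sync_on_step_values))

for ddp, dist_sync_on_step in combinations:
    print(f"ddp={ddp}, dist_sync_on_step={dist_sync_on_step}")
```

With a single `(preds, targets)` entry at class level, that amounts to four invocations of `test_eed_class`.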


# test blank edge cases
def test_eed_empty_functional():
    hyp = []
    ref = [[]]
    assert extended_edit_distance(hyp, ref) == tensor(0.0)


def test_eed_empty_class():
    eed_metric = ExtendedEditDistance()
    hyp = []
    ref = [[]]
    assert eed_metric(hyp, ref) == tensor(0.0)


def test_eed_empty_with_non_empty_hyp_functional():
    hyp = ["python"]
    ref = [[]]
    assert extended_edit_distance(hyp, ref) == tensor(0.0)


def test_eed_empty_with_non_empty_hyp_class():
    eed_metric = ExtendedEditDistance()
    hyp = ["python"]
    ref = [[]]
    assert eed_metric(hyp, ref) == tensor(0.0)


def test_eed_return_sentence_level_score_functional():
    hyp = _inputs_single_sentence_multiple_references.preds
    ref = _inputs_single_sentence_multiple_references.targets
    _, sentence_eed = extended_edit_distance(hyp, ref, return_sentence_level_score=True)
    # The type check must be asserted; a bare isinstance() call checks nothing
    assert isinstance(sentence_eed, Tensor)


def test_eed_return_sentence_level_class():
    metric = ExtendedEditDistance(return_sentence_level_score=True)
    hyp = _inputs_single_sentence_multiple_references.preds
    ref = _inputs_single_sentence_multiple_references.targets
    _, sentence_eed = metric(hyp, ref)
    # The type check must be asserted; a bare isinstance() call checks nothing
    assert isinstance(sentence_eed, Tensor)
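
The sentence-level tests hinge on an `isinstance` check, and such a check only has an effect when its result is asserted. A minimal, self-contained sketch of the difference; the two helper functions are hypothetical and exist only for this illustration.

```python
def check_without_assert(value):
    # The result of isinstance() is computed and then discarded,
    # so a wrong type goes unnoticed.
    isinstance(value, int)
    return "passed"


def check_with_assert(value):
    # Asserting the result makes a wrong type fail loudly.
    assert isinstance(value, int), f"expected int, got {type(value).__name__}"
    return "passed"


# The unasserted check lets a string through.
assert check_without_assert("not an int") == "passed"

# The asserted check rejects it.
try:
    check_with_assert("not an int")
except AssertionError as err:
    outcome = f"failed: {err}"
else:
    outcome = "passed"
```

In a pytest test, only the asserted form can turn a wrong return type into a test failure.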
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -86,6 +86,7 @@
    BLEUScore,
    CharErrorRate,
    CHRFScore,
    ExtendedEditDistance,
    MatchErrorRate,
    SacreBLEUScore,
    SQuAD,
@@ -115,6 +116,7 @@
    "CosineSimilarity",
    "TweedieDevianceScore",
    "ExplainedVariance",
    "ExtendedEditDistance",
    "F1",
    "F1Score",
    "FBeta",
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -69,6 +69,7 @@
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.cer import char_error_rate
from torchmetrics.functional.text.chrf import chrf_score
from torchmetrics.functional.text.eed import extended_edit_distance
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.rouge import rouge_score
from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score
@@ -93,6 +94,7 @@
    "tweedie_deviance_score",
    "dice_score",
    "explained_variance",
    "extended_edit_distance",
    "f1",
    "f1_score",
    "fbeta",