Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New metric: EditDistance #1906

Merged
merged 21 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `VisualInformationFidelity` to image package ([#1830](https://github.com/Lightning-AI/torchmetrics/pull/1830))


- Added `EditDistance` to text package ([#1906](https://github.com/Lightning-AI/torchmetrics/pull/1906))


### Changed

-
Expand Down
23 changes: 23 additions & 0 deletions docs/source/text/edit.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Edit Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg
:tags: Text

.. include:: ../links.rst

#############
Edit Distance
#############

Module Interface
________________

.. autoclass:: torchmetrics.text.EditDistance
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.text.edit_distance
:noindex:
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from torchmetrics.functional.text.bleu import bleu_score
from torchmetrics.functional.text.cer import char_error_rate
from torchmetrics.functional.text.chrf import chrf_score
from torchmetrics.functional.text.edit import edit_distance
from torchmetrics.functional.text.eed import extended_edit_distance
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.functional.text.perplexity import perplexity
Expand All @@ -36,6 +37,7 @@
"bleu_score",
"char_error_rate",
"chrf_score",
"edit_distance",
"extended_edit_distance",
"match_error_rate",
"perplexity",
Expand Down
118 changes: 118 additions & 0 deletions src/torchmetrics/functional/text/edit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.functional.text.helper import _LevenshteinEditDistance as _LE_distance


def _edit_distance_update(
preds: Union[str, Sequence[str]],
target: Union[str, Sequence[str]],
substitution_cost: int = 1,
) -> Tensor:
if isinstance(preds, str):
preds = [preds]
if isinstance(target, str):
target = [target]
if not all(isinstance(x, str) for x in preds):
raise ValueError(f"Expected all values in argument `preds` to be string type, but got {preds}")
if not all(isinstance(x, str) for x in target):
raise ValueError(f"Expected all values in argument `target` to be string type, but got {target}")
if len(preds) != len(target):
raise ValueError(
f"Expected argument `preds` and `target` to have same length, but got {len(preds)} and {len(target)}"
)

distance = [
_LE_distance(t, op_substitute=substitution_cost)(p)[0] for p, t in zip(preds, target) # type: ignore[arg-type]
]
return torch.tensor(distance, dtype=torch.int)


def _edit_distance_compute(
edit_scores: Tensor,
num_elements: Union[Tensor, int],
reduction: Optional[Literal["mean", "sum", "none"]] = "mean",
) -> Tensor:
"""Compute final edit distance reduced over the batch."""
if edit_scores.numel() == 0:
return torch.tensor(0, dtype=torch.int32)
if reduction == "mean":
return edit_scores.sum() / num_elements
if reduction == "sum":
return edit_scores.sum()
if reduction is None or reduction == "none":
return edit_scores
raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None")


def edit_distance(
    preds: Union[str, Sequence[str]],
    target: Union[str, Sequence[str]],
    substitution_cost: int = 1,
    reduction: Optional[Literal["mean", "sum", "none"]] = "mean",
) -> Tensor:
    """Calculates the Levenshtein edit distance between two sequences.

    The edit distance counts the character substitutions, insertions and deletions needed to turn a predicted
    text into the corresponding reference text; a lower value means a closer match.

    Implementation is similar to `nltk.edit_distance <https://www.nltk.org/_modules/nltk/metrics/distance.html>`_.

    Args:
        preds: An iterable of predicted texts (strings).
        target: An iterable of reference texts (strings).
        substitution_cost: The cost of substituting one character for another.
        reduction: a method to reduce metric score over samples.

            - ``'mean'``: takes the mean over samples
            - ``'sum'``: takes the sum over samples
            - ``None`` or ``'none'``: return the score per sample

    Raises:
        ValueError:
            If ``preds`` and ``target`` do not have the same length.
        ValueError:
            If ``preds`` or ``target`` contain non-string values.

    Example::
        Basic example with two strings. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits:

        >>> from torchmetrics.functional.text import edit_distance
        >>> edit_distance(["rain"], ["shine"])
        tensor(3.)

    Example::
        Basic example with two strings and substitution cost of 2. Going from “rain” -> “sain” -> “shin” -> “shine”
        takes 3 edits, where two of them are substitutions:

        >>> from torchmetrics.functional.text import edit_distance
        >>> edit_distance(["rain"], ["shine"], substitution_cost=2)
        tensor(5.)

    Example::
        Multiple strings example:

        >>> from torchmetrics.functional.text import edit_distance
        >>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction=None)
        tensor([3, 4], dtype=torch.int32)
        >>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction="mean")
        tensor(3.5000)

    """
    # Validate the inputs and collect one raw distance per sample, then reduce over the batch.
    edit_scores = _edit_distance_update(preds, target, substitution_cost)
    return _edit_distance_compute(edit_scores, num_elements=edit_scores.numel(), reduction=reduction)
52 changes: 24 additions & 28 deletions src/torchmetrics/functional/text/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,37 +51,35 @@ class _EditOperations(str, Enum):
OP_UNDEFINED = "undefined"


class _EditOperationsCost(IntEnum):
"""Enumerations for the Levenhstein edit operation costs."""

OP_INSERT = 1
OP_DELETE = 1
OP_SUBSTITUTE = 1
OP_NOTHING = 0
OP_UNDEFINED = _INT_INFINITY


class _LevenshteinEditDistance:
"""A convenience class for calculating the Levenshtein edit distance.

Class will cache some intermediate values to hasten the calculation. The implementation follows the implementation
from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/lib_ter.py, from which most of this
implementation is adapted and copied.
"""

def __init__(self, reference_tokens: List[str]) -> None:
"""Initialize _LevenshteinEditDistance object.
Args:
reference_tokens: list of reference tokens
op_insert: cost of insertion operation
op_delete: cost of deletion operation
op_substitute: cost of substitution operation
"""

Args:
reference_tokens:
A tokenized reference sentence.
"""
def __init__(
self, reference_tokens: List[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1
) -> None:
self.reference_tokens = reference_tokens
self.reference_len = len(reference_tokens)

self.cache: Dict[str, Tuple[int, str]] = {}
self.cache_size = 0

self.op_insert = op_insert
self.op_delete = op_delete
self.op_substitute = op_substitute
self.op_nothing = 0
self.op_undefined = _INT_INFINITY

def __call__(self, prediction_tokens: List[str]) -> Tuple[int, Tuple[_EditOperations, ...]]:
"""Calculate edit distance between self._words_ref and the hypothesis. Uses cache to skip some computations.

Expand Down Expand Up @@ -140,15 +138,15 @@ def _levenshtein_edit_distance(
for j in range(min_j, max_j):
if j == 0:
edit_distance[i][j] = (
edit_distance[i - 1][j][0] + _EditOperationsCost.OP_DELETE,
edit_distance[i - 1][j][0] + self.op_delete,
_EditOperations.OP_DELETE,
)
else:
if prediction_tokens[i - 1] == self.reference_tokens[j - 1]:
cost_substitute = _EditOperationsCost.OP_NOTHING
cost_substitute = self.op_nothing
operation_substitute = _EditOperations.OP_NOTHING
else:
cost_substitute = _EditOperationsCost.OP_SUBSTITUTE
cost_substitute = self.op_substitute
operation_substitute = _EditOperations.OP_SUBSTITUTE

# Tercom prefers no-op/sub, then insertion, then deletion. But since we flip the trace and compute
Expand All @@ -157,8 +155,8 @@ def _levenshtein_edit_distance(
# Copied from: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/ter.py.
operations = (
(edit_distance[i - 1][j - 1][0] + cost_substitute, operation_substitute),
(edit_distance[i - 1][j][0] + _EditOperationsCost.OP_DELETE, _EditOperations.OP_DELETE),
(edit_distance[i][j - 1][0] + _EditOperationsCost.OP_INSERT, _EditOperations.OP_INSERT),
(edit_distance[i - 1][j][0] + self.op_delete, _EditOperations.OP_DELETE),
(edit_distance[i][j - 1][0] + self.op_insert, _EditOperations.OP_INSERT),
)

for operation_cost, operation_name in operations:
Expand Down Expand Up @@ -265,8 +263,7 @@ def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tupl

return start_position, edit_distance

@staticmethod
def _get_empty_row(length: int) -> List[Tuple[int, _EditOperations]]:
def _get_empty_row(self, length: int) -> List[Tuple[int, _EditOperations]]:
"""Precomputed empty matrix row for Levenhstein edit distance.

Args:
Expand All @@ -275,10 +272,9 @@ def _get_empty_row(length: int) -> List[Tuple[int, _EditOperations]]:
Return:
A list of tuples containing infinite edit operation costs and yet undefined edit operations.
"""
return [(int(_EditOperationsCost.OP_UNDEFINED), _EditOperations.OP_UNDEFINED)] * (length + 1)
return [(int(self.op_undefined), _EditOperations.OP_UNDEFINED)] * (length + 1)

@staticmethod
def _get_initial_row(length: int) -> List[Tuple[int, _EditOperations]]:
def _get_initial_row(self, length: int) -> List[Tuple[int, _EditOperations]]:
"""First row corresponds to insertion operations of the reference, so we do 1 edit operation per reference word.

Args:
Expand All @@ -287,7 +283,7 @@ def _get_initial_row(length: int) -> List[Tuple[int, _EditOperations]]:
Return:
A list of tuples containing edit operation costs of insert and insert edit operations.
"""
return [(i * _EditOperationsCost.OP_INSERT, _EditOperations.OP_INSERT) for i in range(length + 1)]
return [(i * self.op_insert, _EditOperations.OP_INSERT) for i in range(length + 1)]


def _validate_inputs(
Expand Down
2 changes: 2 additions & 0 deletions src/torchmetrics/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchmetrics.text.bleu import BLEUScore
from torchmetrics.text.cer import CharErrorRate
from torchmetrics.text.chrf import CHRFScore
from torchmetrics.text.edit import EditDistance
from torchmetrics.text.eed import ExtendedEditDistance
from torchmetrics.text.mer import MatchErrorRate
from torchmetrics.text.perplexity import Perplexity
Expand All @@ -35,6 +36,7 @@
"BLEUScore",
"CharErrorRate",
"CHRFScore",
"EditDistance",
"ExtendedEditDistance",
"MatchErrorRate",
"Perplexity",
Expand Down
Loading