diff --git a/CHANGELOG.md b/CHANGELOG.md index 40ba935bf4b..73795d4194e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1621](https://github.com/Lightning-AI/metrics/pull/1621), [#1624](https://github.com/Lightning-AI/metrics/pull/1624), [#1623](https://github.com/Lightning-AI/metrics/pull/1623), + [#1631](https://github.com/Lightning-AI/metrics/pull/1631), ) diff --git a/src/torchmetrics/text/cer.py b/src/torchmetrics/text/cer.py index 6dcd06cc9a2..4d90ad1b754 100644 --- a/src/torchmetrics/text/cer.py +++ b/src/torchmetrics/text/cer.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.text.cer import _cer_compute, _cer_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CharErrorRate.plot"] class CharErrorRate(Metric): @@ -84,3 +88,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the character error rate.""" return _cer_compute(self.errors, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import CharErrorRate + >>> metric = CharErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import CharErrorRate + >>> metric = CharErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index a28c6b18294..56c5c27c230 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union from torch import Tensor, stack from typing_extensions import Literal from torchmetrics.functional.text.eed import _eed_compute, _eed_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ExtendedEditDistance.plot"] class ExtendedEditDistance(Metric): @@ -112,3 +116,46 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: if self.return_sentence_level_score: return average, stack(self.sentence_eed) return average + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import ExtendedEditDistance + >>> metric = ExtendedEditDistance() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import ExtendedEditDistance + >>> metric = ExtendedEditDistance() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/mer.py b/src/torchmetrics/text/mer.py index 088a5ab9fe0..1e6282b5c85 100644 --- a/src/torchmetrics/text/mer.py +++ b/src/torchmetrics/text/mer.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.text.mer import _mer_compute, _mer_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MatchErrorRate.plot"] class MatchErrorRate(Metric): @@ -85,3 +89,46 @@ def update( def compute(self) -> Tensor: """Calculate the Match error rate.""" return _mer_compute(self.errors, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import MatchErrorRate + >>> metric = MatchErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import MatchErrorRate + >>> metric = MatchErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/wer.py b/src/torchmetrics/text/wer.py index 4687da840b5..2d62438a292 100644 --- a/src/torchmetrics/text/wer.py +++ b/src/torchmetrics/text/wer.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.text.wer import _wer_compute, _wer_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WordErrorRate.plot"] class WordErrorRate(Metric): @@ -82,3 +87,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the word error rate.""" return _wer_compute(self.errors, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import WordErrorRate + >>> metric = WordErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import WordErrorRate + >>> metric = WordErrorRate() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/wil.py b/src/torchmetrics/text/wil.py index d0c9c8bc96d..698854366f0 100644 --- a/src/torchmetrics/text/wil.py +++ b/src/torchmetrics/text/wil.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.text.wil import _wil_compute, _wil_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WordInfoLost.plot"] class WordInfoLost(Metric): @@ -83,3 +87,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the Word Information Lost.""" return _wil_compute(self.errors, self.target_total, self.preds_total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import WordInfoLost + >>> metric = WordInfoLost() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import WordInfoLost + >>> metric = WordInfoLost() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/text/wip.py b/src/torchmetrics/text/wip.py index 335c524a00d..992d6021eca 100644 --- a/src/torchmetrics/text/wip.py +++ b/src/torchmetrics/text/wip.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.text.wip import _wip_compute, _wip_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["WordInfoPreserved.plot"] class WordInfoPreserved(Metric): @@ -84,3 +88,46 @@ def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> def compute(self) -> Tensor: """Calculate the Word Information Preserved.""" return _wip_compute(self.errors, self.target_total, self.preds_total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import WordInfoPreserved + >>> metric = WordInfoPreserved() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import WordInfoPreserved + >>> metric = WordInfoPreserved() + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index dc4c742a87b..c181c88ceea 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -117,6 +117,14 @@ RetrievalRecallAtFixedPrecision, RetrievalRPrecision, ) +from torchmetrics.text import ( + CharErrorRate, + ExtendedEditDistance, + MatchErrorRate, + WordErrorRate, + WordInfoLost, + WordInfoPreserved, +) _rand_input = lambda: torch.rand(10) _binary_randint_input = lambda: torch.randint(2, (10,)) @@ -130,6 +138,8 @@ torch.tensor([1, 1, 0, 0, 0, 0, 1, 1]).float(), 40, replacement=True ).reshape(1, 5, 4, 2) _nominal_input = lambda: torch.randint(0, 4, (100,)) +_text_input_1 = lambda: ["this is the prediction", "there is an other sample"] +_text_input_2 = lambda: ["this is the reference", "there is another one"] @pytest.mark.parametrize( @@ -422,6 +432,12 @@ pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), pytest.param(WeightedMeanAbsolutePercentageError, _rand_input, _rand_input, id="weighted mape"), + pytest.param(WordInfoPreserved, _text_input_1, _text_input_2, id="word info preserved"), + pytest.param(WordInfoLost, _text_input_1, _text_input_2, id="word info lost"), + pytest.param(WordErrorRate, _text_input_1, _text_input_2, id="word error rate"), + pytest.param(CharErrorRate, _text_input_1, _text_input_2, id="character error rate"), + pytest.param(ExtendedEditDistance, _text_input_1, _text_input_2, id="extended edit distance"), + pytest.param(MatchErrorRate, _text_input_1, _text_input_2, id="match error rate"), ], ) @pytest.mark.parametrize("num_vals", [1, 5])