From 0e20e73dc15a3968cbc33ee2aa7cc8811db34f96 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 15 Mar 2023 08:47:33 +0100 Subject: [PATCH 1/3] base adding --- src/torchmetrics/regression/pearson.py | 48 ++++++++++++++++++- src/torchmetrics/regression/r2.py | 48 ++++++++++++++++++- src/torchmetrics/regression/spearman.py | 48 ++++++++++++++++++- src/torchmetrics/regression/symmetric_mape.py | 48 ++++++++++++++++++- .../regression/tweedie_deviance.py | 48 ++++++++++++++++++- src/torchmetrics/regression/wmape.py | 48 ++++++++++++++++++- 6 files changed, 282 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index 8004c4f3b54..fd88916500a 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute, _pearson_corrcoef_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanSquaredError.plot"] def _final_aggregation( @@ -159,3 +164,44 @@ def compute(self) -> Tensor: corr_xy = self.corr_xy n_total = self.n_total return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/r2.py b/src/torchmetrics/regression/r2.py index 5b2289354af..99ad2e1bcdc 100644 --- a/src/torchmetrics/regression/r2.py +++ b/src/torchmetrics/regression/r2.py @@ -11,13 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor, tensor from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanSquaredError.plot"] class R2Score(Metric): @@ -129,3 +134,44 @@ def compute(self) -> Tensor: return _r2_score_compute( self.sum_squared_error, self.sum_error, self.residual, self.total, self.adjusted, self.multioutput ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 6c09db268a5..312dc968a58 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, List, Optional, Sequence, Union from torch import Tensor @@ -19,6 +19,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanSquaredError.plot"] class SpearmanCorrCoef(Metric): @@ -95,3 +100,44 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _spearman_corrcoef_compute(preds, target) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/symmetric_mape.py b/src/torchmetrics/regression/symmetric_mape.py index faec0a517de..360f093da76 100644 --- a/src/torchmetrics/regression/symmetric_mape.py +++ b/src/torchmetrics/regression/symmetric_mape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union from torch import Tensor, tensor @@ -20,6 +20,11 @@ _symmetric_mean_absolute_percentage_error_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanSquaredError.plot"] class SymmetricMeanAbsolutePercentageError(Metric): @@ -74,3 +79,44 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute mean absolute percentage error over state.""" return _symmetric_mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/tweedie_deviance.py b/src/torchmetrics/regression/tweedie_deviance.py index 70485af8e3c..4e493a437c4 100644 --- a/src/torchmetrics/regression/tweedie_deviance.py +++ b/src/torchmetrics/regression/tweedie_deviance.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -21,6 +21,11 @@ _tweedie_deviance_score_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanSquaredError.plot"] class TweedieDevianceScore(Metric): @@ -99,3 +104,44 @@ def update(self, preds: Tensor, targets: Tensor) -> None: def compute(self) -> Tensor: """Compute metric.""" return _tweedie_deviance_score_compute(self.sum_deviance_score, self.num_observations) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/regression/wmape.py b/src/torchmetrics/regression/wmape.py index 0df0d3e2d58..41140db58e1 100644 --- a/src/torchmetrics/regression/wmape.py +++ b/src/torchmetrics/regression/wmape.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor @@ -21,6 +21,11 @@ _weighted_mean_absolute_percentage_error_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MeanSquaredError.plot"] class WeightedMeanAbsolutePercentageError(Metric): @@ -75,3 +80,44 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute weighted mean absolute percentage error over state.""" return _weighted_mean_absolute_percentage_error_compute(self.sum_abs_error, self.sum_scale) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) From abf1a4bcfa43fd83776027bd156cb1336f1feceb Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 15 Mar 2023 08:54:49 +0100 Subject: [PATCH 2/3] tests --- src/torchmetrics/regression/pearson.py | 10 +++++----- src/torchmetrics/regression/r2.py | 10 +++++----- src/torchmetrics/regression/spearman.py | 10 +++++----- src/torchmetrics/regression/symmetric_mape.py | 10 +++++----- src/torchmetrics/regression/tweedie_deviance.py | 10 +++++----- src/torchmetrics/regression/wmape.py | 10 +++++----- tests/unittests/utilities/test_plot.py | 12 ++++++++++++ 7 files changed, 42 insertions(+), 30 deletions(-) diff --git a/src/torchmetrics/regression/pearson.py b/src/torchmetrics/regression/pearson.py index fd88916500a..498f8ff0307 100644 --- a/src/torchmetrics/regression/pearson.py +++ b/src/torchmetrics/regression/pearson.py @@ -22,7 +22,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["MeanSquaredError.plot"] + __doctest_skip__ = ["PearsonCorrCoef.plot"] def _final_aggregation( @@ -187,8 +187,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import PearsonCorrCoef + >>> metric = PearsonCorrCoef() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -197,8 +197,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import PearsonCorrCoef + >>> metric = PearsonCorrCoef() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/src/torchmetrics/regression/r2.py b/src/torchmetrics/regression/r2.py index 99ad2e1bcdc..f041662d9a6 100644 --- a/src/torchmetrics/regression/r2.py +++ b/src/torchmetrics/regression/r2.py @@ -22,7 +22,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["MeanSquaredError.plot"] + __doctest_skip__ = ["R2Score.plot"] class R2Score(Metric): @@ -157,8 +157,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import R2Score + >>> metric = R2Score() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -167,8 +167,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import R2Score + >>> metric = R2Score() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/src/torchmetrics/regression/spearman.py b/src/torchmetrics/regression/spearman.py index 312dc968a58..2bb0263cb8e 100644 --- a/src/torchmetrics/regression/spearman.py +++ b/src/torchmetrics/regression/spearman.py @@ -23,7 +23,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["MeanSquaredError.plot"] + __doctest_skip__ = ["SpearmanCorrCoef.plot"] class SpearmanCorrCoef(Metric): @@ -123,8 +123,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import SpearmanCorrCoef + >>> metric = SpearmanCorrCoef() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -133,8 +133,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import SpearmanCorrCoef + >>> metric = SpearmanCorrCoef() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/src/torchmetrics/regression/symmetric_mape.py b/src/torchmetrics/regression/symmetric_mape.py index 360f093da76..6c994d0ea38 100644 --- a/src/torchmetrics/regression/symmetric_mape.py +++ b/src/torchmetrics/regression/symmetric_mape.py @@ -24,7 +24,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["MeanSquaredError.plot"] + __doctest_skip__ = ["SymmetricMeanAbsolutePercentageError.plot"] class SymmetricMeanAbsolutePercentageError(Metric): @@ -102,8 +102,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError + >>> metric = SymmetricMeanAbsolutePercentageError() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -112,8 +112,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError + >>> metric = SymmetricMeanAbsolutePercentageError() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/src/torchmetrics/regression/tweedie_deviance.py b/src/torchmetrics/regression/tweedie_deviance.py index 4e493a437c4..53d5175c7fc 100644 --- a/src/torchmetrics/regression/tweedie_deviance.py +++ b/src/torchmetrics/regression/tweedie_deviance.py @@ -25,7 +25,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["MeanSquaredError.plot"] + __doctest_skip__ = ["TweedieDevianceScore.plot"] class TweedieDevianceScore(Metric): @@ -127,8 +127,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import TweedieDevianceScore + >>> metric = TweedieDevianceScore() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -137,8 +137,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import TweedieDevianceScore + >>> metric = TweedieDevianceScore() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/src/torchmetrics/regression/wmape.py b/src/torchmetrics/regression/wmape.py index 41140db58e1..d8a55389f13 100644 --- a/src/torchmetrics/regression/wmape.py +++ b/src/torchmetrics/regression/wmape.py @@ -25,7 +25,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["MeanSquaredError.plot"] + __doctest_skip__ = ["WeightedMeanAbsolutePercentageError.plot"] class WeightedMeanAbsolutePercentageError(Metric): @@ -103,8 +103,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError + >>> metric = WeightedMeanAbsolutePercentageError() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -113,8 +113,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError + >>> metric = WeightedMeanAbsolutePercentageError() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index f9fb0c243d7..8f9f11ba0f3 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -80,6 +80,12 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + PearsonCorrCoef, + R2Score, + SpearmanCorrCoef, + SymmetricMeanAbsolutePercentageError, + TweedieDevianceScore, + WeightedMeanAbsolutePercentageError, ) from torchmetrics.retrieval import RetrievalMRR, RetrievalPrecision, RetrievalRecall, RetrievalRPrecision @@ -293,6 +299,12 @@ pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"), pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"), pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), + pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), + pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), + pytest.param(SpearmanCorrCoef, _rand_input, _rand_input, id="spearman corr coef"), + pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), + pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), + pytest.param(WeightedMeanAbsolutePercentageError, _rand_input, _rand_input, id="weighted mape"), ], ) @pytest.mark.parametrize("num_vals", [1, 5]) From b195d411543f325d3d6bc4221f80d631ad937985 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 15 Mar 2023 08:55:30 +0100 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b62d97e6cd0..d942a583ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1605](https://github.com/Lightning-AI/metrics/pull/1605), [#1610](https://github.com/Lightning-AI/metrics/pull/1610), [#1609](https://github.com/Lightning-AI/metrics/pull/1609), + [#1621](https://github.com/Lightning-AI/metrics/pull/1621), )