From aa1457fce070a7f057b4dc80a48b4fad0ddf8f48 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 9 Mar 2024 12:51:45 +0100 Subject: [PATCH 01/30] init files --- src/torchmetrics/__init__.py | 36 ++++++++++--------- src/torchmetrics/functional/__init__.py | 24 +++++++------ .../functional/regression/__init__.py | 8 +++-- src/torchmetrics/regression/__init__.py | 8 +++-- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index b1549dfaf8b..8f6aee6867f 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -107,6 +107,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -151,25 +152,23 @@ ) __all__ = [ - "functional", - "Accuracy", "AUROC", + "Accuracy", "AveragePrecision", "BLEUScore", "BootStrapper", + "CHRFScore", "CalibrationError", "CatMetric", - "ClasswiseWrapper", "CharErrorRate", - "CHRFScore", - "ConcordanceCorrCoef", + "ClasswiseWrapper", "CohenKappa", + "ConcordanceCorrCoef", "ConfusionMatrix", "CosineSimilarity", "CramersV", "CriticalSuccessIndex", "Dice", - "TweedieDevianceScore", "ErrorRelativeGlobalDimensionlessSynthesis", "ExactMatch", "ExplainedVariance", @@ -180,8 +179,8 @@ "HammingDistance", "HingeLoss", "JaccardIndex", - "KendallRankCorrCoef", "KLDivergence", + "KendallRankCorrCoef", "LogCoshError", "MatchErrorRate", "MatthewsCorrCoef", @@ -194,14 +193,16 @@ "Metric", "MetricCollection", "MetricTracker", - "MinkowskiDistance", "MinMaxMetric", "MinMetric", + "MinkowskiDistance", "ModifiedPanopticQuality", + "MultiScaleStructuralSimilarityIndexMeasure", "MultioutputWrapper", "MultitaskWrapper", - "MultiScaleStructuralSimilarityIndexMeasure", + "NormalizedRootMeanSquaredError", "PanopticQuality", + "PeakSignalNoiseRatio", "PearsonCorrCoef", "PearsonsContingencyCoefficient", "PermutationInvariantTraining", @@ -209,8 +210,8 @@ "Precision", "PrecisionAtFixedRecall", "PrecisionRecallCurve", - "PeakSignalNoiseRatio", "R2Score", + "ROC", "Recall", "RecallAtFixedPrecision", "RelativeAverageSpectralError", @@ -221,37 +222,38 @@ "RetrievalMRR", "RetrievalNormalizedDCG", "RetrievalPrecision", - "RetrievalRecall", - "RetrievalRPrecision", "RetrievalPrecisionRecallCurve", + "RetrievalRPrecision", + "RetrievalRecall", "RetrievalRecallAtFixedPrecision", - "ROC", "RootMeanSquaredErrorUsingSlidingWindow", "RunningMean", "RunningSum", + "SQuAD", "SacreBLEUScore", - "SignalDistortionRatio", "ScaleInvariantSignalDistortionRatio", "ScaleInvariantSignalNoiseRatio", + "SensitivityAtSpecificity", + "SignalDistortionRatio", "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity", "SpecificityAtSensitivity", - "SensitivityAtSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", - "SQuAD", - "StructuralSimilarityIndexMeasure", "StatScores", + "StructuralSimilarityIndexMeasure", "SumMetric", "SymmetricMeanAbsolutePercentageError", "TheilsU", "TotalVariation", "TranslationEditRate", "TschuprowsT", + "TweedieDevianceScore", "UniversalImageQualityIndex", "WeightedMeanAbsolutePercentageError", "WordErrorRate", "WordInfoLost", "WordInfoPreserved", + "functional", ] diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 30a7145aa71..7de7f261867 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -100,6 +100,7 @@ mean_squared_error, mean_squared_log_error, minkowski_distance, + normalized_root_mean_squared_error, 
pearson_corrcoef, r2_score, relative_squared_error, @@ -146,14 +147,13 @@ "calibration_error", "char_error_rate", "chrf_score", - "concordance_corrcoef", "cohen_kappa", + "concordance_corrcoef", "confusion_matrix", "cosine_similarity", "cramers_v", "cramers_v_matrix", "critical_success_index", - "tweedie_deviance_score", "dice", "error_relative_global_dimensionless_synthesis", "exact_match", @@ -177,12 +177,14 @@ "mean_squared_log_error", "minkowski_distance", "multiscale_structural_similarity_index_measure", + "normalized_root_mean_squared_error", "pairwise_cosine_similarity", "pairwise_euclidean_distance", "pairwise_linear_similarity", "pairwise_manhattan_distance", "pairwise_minkowski_distance", "panoptic_quality", + "peak_signal_noise_ratio", "pearson_corrcoef", "pearsons_contingency_coefficient", "pearsons_contingency_coefficient_matrix", @@ -190,10 +192,11 @@ "perplexity", "pit_permutate", "precision", + "precision_at_fixed_recall", "precision_recall_curve", - "peak_signal_noise_ratio", "r2_score", "recall", + "recall_at_fixed_precision", "relative_average_spectral_error", "relative_squared_error", "retrieval_average_precision", @@ -201,24 +204,27 @@ "retrieval_hit_rate", "retrieval_normalized_dcg", "retrieval_precision", + "retrieval_precision_recall_curve", "retrieval_r_precision", "retrieval_recall", "retrieval_reciprocal_rank", - "retrieval_precision_recall_curve", "roc", "root_mean_squared_error_using_sliding_window", "rouge_score", "sacre_bleu_score", - "signal_distortion_ratio", "scale_invariant_signal_distortion_ratio", "scale_invariant_signal_noise_ratio", + "sensitivity_at_specificity", + "signal_distortion_ratio", "signal_noise_ratio", "spearman_corrcoef", "specificity", + "specificity_at_sensitivity", + "spectral_angle_mapper", "spectral_distortion_index", "squad", - "structural_similarity_index_measure", "stat_scores", + "structural_similarity_index_measure", "symmetric_mean_absolute_percentage_error", "theils_u", "theils_u_matrix", @@ -226,14 +232,10 @@ "translation_edit_rate", "tschuprows_t", "tschuprows_t_matrix", + "tweedie_deviance_score", "universal_image_quality_index", - "spectral_angle_mapper", "weighted_mean_absolute_percentage_error", "word_error_rate", "word_information_lost", "word_information_preserved", - "precision_at_fixed_recall", - "recall_at_fixed_precision", - "sensitivity_at_specificity", - "specificity_at_sensitivity", ] diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index c2dab8c5f59..063fbc059e3 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.functional.regression.mape import mean_absolute_percentage_error from torchmetrics.functional.regression.minkowski import minkowski_distance from torchmetrics.functional.regression.mse import mean_squared_error +from torchmetrics.functional.regression.nrmse import normalized_root_mean_squared_error from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.functional.regression.r2 import r2_score from torchmetrics.functional.regression.rse import relative_squared_error @@ -39,13 +40,14 @@ "kendall_rank_corrcoef", "kl_divergence", "log_cosh_error", - "mean_squared_log_error", "mean_absolute_error", - "mean_squared_error", - "pearson_corrcoef", "mean_absolute_percentage_error", "mean_absolute_percentage_error", + "mean_squared_error", + "mean_squared_log_error", "minkowski_distance", + 
"normalized_root_mean_squared_error", + "pearson_corrcoef", "r2_score", "relative_squared_error", "spearman_corrcoef", diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index 03ba8023a10..6a41c01bcdb 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -23,6 +23,7 @@ from torchmetrics.regression.mape import MeanAbsolutePercentageError from torchmetrics.regression.minkowski import MinkowskiDistance from torchmetrics.regression.mse import MeanSquaredError +from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.pearson import PearsonCorrCoef from torchmetrics.regression.r2 import R2Score from torchmetrics.regression.rse import RelativeSquaredError @@ -36,14 +37,15 @@ "CosineSimilarity", "CriticalSuccessIndex", "ExplainedVariance", - "KendallRankCorrCoef", "KLDivergence", + "KendallRankCorrCoef", "LogCoshError", - "MeanSquaredLogError", "MeanAbsoluteError", "MeanAbsolutePercentageError", - "MinkowskiDistance", "MeanSquaredError", + "MeanSquaredLogError", + "MinkowskiDistance", + "NormalizedRootMeanSquaredError", "PearsonCorrCoef", "R2Score", "RelativeSquaredError", From 58b75778c9f9d2667468fccf51dbd95851203525 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 9 Mar 2024 12:52:14 +0100 Subject: [PATCH 02/30] docs --- .../normalized_root_mean_squared_error.rst | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 docs/source/regression/normalized_root_mean_squared_error.rst diff --git a/docs/source/regression/normalized_root_mean_squared_error.rst b/docs/source/regression/normalized_root_mean_squared_error.rst new file mode 100644 index 00000000000..7bbc2f392d5 --- /dev/null +++ b/docs/source/regression/normalized_root_mean_squared_error.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Normalized Root Mean Squared Error (NRMSE) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Regression + +.. include:: ../links.rst + +########################################## +Normalized Root Mean Squared Error (NRMSE) +########################################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.NormalizedRootMeanSquaredError + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.normalized_root_mean_squared_error From 7933f615be2d0df871a1bb82b50f5d1fa4ec0dc2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 9 Mar 2024 12:53:37 +0100 Subject: [PATCH 03/30] requirements for testing --- docs/source/links.rst | 1 + requirements/_devel.txt | 1 + requirements/_docs.txt | 1 + requirements/reqression_test.txt | 1 + 4 files changed, 4 insertions(+) create mode 100644 requirements/reqression_test.txt diff --git a/docs/source/links.rst b/docs/source/links.rst index 7034f764d65..18def84c047 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -170,3 +170,4 @@ .. _FLORES-200: https://arxiv.org/abs/2207.04672 .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013 +.. 
_Normalized Root Mean Squared Error: https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109 diff --git a/requirements/_devel.txt b/requirements/_devel.txt index 596cc138133..6a8ea2b8e7f 100644 --- a/requirements/_devel.txt +++ b/requirements/_devel.txt @@ -20,3 +20,4 @@ -r classification_test.txt -r nominal_test.txt -r segmentation_test.txt +-r regression_test.txt diff --git a/requirements/_docs.txt b/requirements/_docs.txt index 4fbb0d08291..b4b7ebb6633 100644 --- a/requirements/_docs.txt +++ b/requirements/_docs.txt @@ -24,3 +24,4 @@ pydantic > 1.0.0, < 3.0.0 -r multimodal.txt -r text.txt -r text_test.txt +-r regression_test.txt diff --git a/requirements/reqression_test.txt b/requirements/reqression_test.txt new file mode 100644 index 00000000000..859605fda3b --- /dev/null +++ b/requirements/reqression_test.txt @@ -0,0 +1 @@ +permetrics==2.0.0 From ba8848a0886f66b926363c52ae0d36e8b2138fec Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 9 Mar 2024 14:47:20 +0100 Subject: [PATCH 04/30] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04c304c4712..7764c9c5da4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `QualityWithNoReference` metric ([#2288](https://github.com/Lightning-AI/torchmetrics/pull/2288)) +- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442)) + + ### Changed - Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) From 0bf7f56f31d469ed647f62e81422fc37e5fe04d7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 9 Mar 2024 14:48:50 +0100 Subject: [PATCH 05/30] add class interface --- .../functional/regression/nrmse.py | 89 +++++++ src/torchmetrics/regression/nrmse.py | 236 ++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 src/torchmetrics/functional/regression/nrmse.py create mode 100644 src/torchmetrics/regression/nrmse.py diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py new file mode 100644 index 00000000000..6bb5f87262e --- /dev/null +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -0,0 +1,89 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Tuple + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.mse import _mean_squared_error_update + + +def _normalized_root_mean_squared_error_update( + preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std"] = "mean" +) -> Tuple[Tensor, int, Tensor]: + sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs) + + target = target.view(-1) if num_outputs == 1 else target + if normalization == "mean": + denom = torch.mean(target, dim=0) + elif normalization == "range": + denom = torch.max(target, dim=0).values - torch.min(target, dim=0).values + elif normalization == "std": + denom = torch.std(target, correction=0, dim=0) + else: + raise ValueError(f"Argument `normalization` should be either 'mean', 'range' or 'std', but got {normalization}") + return sum_squared_error, num_obs, denom + + +def _normalized_root_mean_squared_error_compute(sum_squared_error: Tensor, num_obs: int, denom: Tensor) -> Tensor: + """Calculates RMSE and normalizes it.""" + rmse = torch.sqrt(sum_squared_error / num_obs) + return rmse / denom + + +def normalized_root_mean_squared_error( + preds: Tensor, + target: Tensor, + normalization: Literal["mean", "range", "std"] = "mean", + num_outputs: int = 1, +) -> Tensor: + """Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. + + Args: + preds: estimated labels + target: ground truth labels + normalization: type of normalization to be applied. Choose from "mean", "range", "std" which corresponds to + normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the + target. + num_outputs: Number of outputs in multioutput setting + + Return: + Tensor with the NRMSE score + + Example: + >>> import torch + >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error + >>> preds = torch.tensor([0., 1, 2, 3]) + >>> target = torch.tensor([0., 1, 2, 2]) + >>> normalized_root_mean_squared_error(preds, target, normalization="mean") + tensor(0.4000) + >>> normalized_root_mean_squared_error(preds, target, normalization="range") + tensor(0.2500) + >>> normalized_root_mean_squared_error(preds, target, normalization="std") + tensor(0.5222) + + Example (multioutput): + >>> import torch + >>> from torchmetrics.functional.regression import normalized_root_mean_squared_error + >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]]) + >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]]) + >>> normalized_root_mean_squared_error(preds, target, normalization="mean", num_outputs=2) + tensor([0.2981, 0.2222]) + + """ + sum_squared_error, num_obs, denom = _normalized_root_mean_squared_error_update( + preds, target, num_outputs=num_outputs, normalization=normalization + ) + return _normalized_root_mean_squared_error_compute(sum_squared_error, num_obs, denom) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py new file mode 100644 index 00000000000..1b23e359a57 --- /dev/null +++ b/src/torchmetrics/regression/nrmse.py @@ -0,0 +1,236 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor, tensor +from typing_extensions import Literal + +from torchmetrics.functional.regression.nrmse import ( + _mean_squared_error_update, + _normalized_root_mean_squared_error_compute, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["NormalizedRootMeanSquaredError.plot"] + + +def _final_aggregation( + min_val: Tensor, + max_val: Tensor, + mean_val: Tensor, + std_val: Tensor, + total: Tensor, + normalization: Literal["mean", "range", "std"] = "mean", +) -> Tensor: + if len(min_val) == 1: + if normalization == "mean": + return mean_val[0] + if normalization == "range": + return max_val[0] - min_val[0] + if normalization == "std": + return std_val[0] + + min_val_1, max_val_1, mean_val_1, std_val_1, total_1 = min_val[0], max_val[0], mean_val[0], std_val[0], total[0] + for i in range(1, len(min_val)): + min_val_2, max_val_2, mean_val_2, std_val_2, total_2 = min_val[i], max_val[i], mean_val[i], std_val[i], total[i] + total = total_1 + total_2 + mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total + std = torch.sqrt( + ( + std_val_1**2 * (total_1 - 1) + + std_val_2**2 * (total_2 - 1) + + (mean_val_1 - mean) ** 2 * total_1 + + (mean_val_2 - mean) ** 2 * total_2 + ) + / (total - 1) + ) + min_val = torch.min(min_val_1, min_val_2) + max_val = torch.max(max_val_1, max_val_2) + + if normalization == "mean": + return mean + if normalization == "range": + return max_val - min_val + return std + + +class NormalizedRootMeanSquaredError(Metric): + r"""Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. + + The metric is defined as: + + .. math:: + \text{NRMSE} = \frac{\text{RMSE}}{\text{denom}} + + where RMSE is the root mean squared error and `denom` is the normalization factor. The normalization factor can be + either be the mean, range or standard deviation of the target, which can be set using the `normalization` argument. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model + - ``target`` (:class:`~torch.Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``nrmse`` (:class:`~torch.Tensor`): A tensor with the mean squared error + + Args: + normalization: type of normalization to be applied. Choose from "mean", "range", "std" which corresponds to + normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the + target. + num_outputs: Number of outputs in multioutput setting + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + Example:: + Single output normalized root mean squared error computation: + + >>> import torch + >>> from torchmetrics import NormalizedRootMeanSquaredError + >>> target = tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = tensor([3.0, 5.0, 2.5, 7.0]) + >>> nrmse = NormalizedRootMeanSquaredError(normalization="mean") + >>> nrmse(preds, target) + tensor(0.4000) + >>> nrmse = NormalizedRootMeanSquaredError(normalization="range") + >>> nrmse(preds, target) + + Example:: + Multioutput normalized root mean squared error computation: + + >>> import torch + >>> from torchmetrics import NormalizedRootMeanSquaredError + >>> target = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + >>> preds = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + >>> nrmse = NormalizedRootMeanSquaredError(num_outputs=3) + >>> nrmse(preds, target) + tensor([1., 1., 1.]) + + """ + + is_differentiable = True + higher_is_better = False + full_state_update = False + plot_lower_bound: float = 0.0 + + sum_squared_error: Tensor + total: Tensor + + def __init__( + self, + normalization: Literal["mean", "range", "std"] = "mean", + num_outputs: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + if normalization not in ("mean", "range", "std"): + raise ValueError( + f"Argument `normalization` should be either 'mean', 'range' or 'std', but got {normalization}" + ) + self.normalization = normalization + + if not (isinstance(num_outputs, int) and num_outputs > 0): + raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}") + self.num_outputs = num_outputs + + self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) + self.add_state("max_val", default=-float("Inf") * torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("std_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + + See `mean_squared_error_update` for details. + + """ + sum_squared_error, num_obs = _mean_squared_error_update(preds, target, self.num_outputs) + self.sum_squared_error += sum_squared_error + self.total += num_obs + + # Update normalization statistics + target = target.view(-1) if self.num_outputs == 1 else target + self.min_val = torch.minimum(target.min(dim=0).values, self.min_val) + self.max_val = torch.maximum(target.max(dim=0).values, self.max_val) + new_mean = torch.mean(target, dim=0) + self.mean_val = (self.mean_val * (self.total - num_obs) + new_mean * num_obs) / self.total + new_std = torch.std(target, correction=0, dim=0) + self.std_val = torch.sqrt((self.std_val**2 * (self.total - num_obs) + new_std**2 * num_obs) / self.total) + + def compute(self) -> Tensor: + """Computes NRMSE over state. + + See `mean_squared_error_compute` for details. 
+ + """ + if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1): + denom = _final_aggregation( + self.min_val, self.max_val, self.mean_val, self.std_val, self.total, self.normalization + ) + else: + if self.normalization == "mean": + denom = self.mean_val + elif self.normalization == "range": + denom = self.max_val - self.min_val + else: + denom = self.std_val + return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import MeanSquaredError + >>> metric = MeanSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + + """ + return self._plot(val, ax) From b6c7011c4a82664f85bd223ba8ad0209792bf1cd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 9 Mar 2024 14:49:54 +0100 Subject: [PATCH 06/30] add tests --- tests/unittests/regression/test_mean_error.py | 79 +++++++++++++++++-- 1 file changed, 74 insertions(+), 5 deletions(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index c25882c3f37..420091bd9dd 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -18,6 +18,7 @@ import numpy as np import pytest import torch +from permetrics.regression import RegressionMetric from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error from sklearn.metrics import mean_squared_error as sk_mean_squared_error @@ -29,6 +30,7 @@ mean_absolute_percentage_error, mean_squared_error, mean_squared_log_error, + normalized_root_mean_squared_error, weighted_mean_absolute_percentage_error, ) from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error @@ -39,6 +41,7 @@ MeanSquaredLogError, WeightedMeanAbsolutePercentageError, ) +from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError from unittests import BATCH_SIZE, NUM_BATCHES, _Input @@ -114,6 +117,23 @@ def _reference_symmetric_mape( return np.average(output_errors, weights=multioutput) +def _reference_normalized_root_mean_squared_error( + y_true: np.ndarray, y_pred: np.ndarray, normalization: str = "mean", num_outputs: int = 1 +): + """Reference implementation of Normalized Root Mean Squared Error (NRMSE) metric.""" + if num_outputs == 1: + y_true = 
y_true.flatten() + y_pred = y_pred.flatten() + evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) + arg_mapping = { + "mean": 1, + "range": 2, + "std": 4, + } + + return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + + def _reference_weighted_mean_abs_percentage_error(target, preds): return np.sum(np.abs(target - preds)) / np.sum(np.abs(target)) @@ -122,17 +142,26 @@ def _single_target_ref_wrapper(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - res = sk_fn(sk_target, sk_preds) - - return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res + if metric_args and "normalization" in metric_args: + res = sk_fn(sk_target, sk_preds, normalization=metric_args["normalization"]) + else: + res = sk_fn(sk_target, sk_preds) + if metric_args and "squared" in metric_args and not metric_args["squared"]: + res = math.sqrt(res) + return res def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, NUM_TARGETS).numpy() sk_target = target.view(-1, NUM_TARGETS).numpy() sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} - res = sk_fn(sk_target, sk_preds, **sk_kwargs) - return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res + if metric_args and "normalization" in metric_args: + res = sk_fn(sk_target, sk_preds, **metric_args) + else: + res = sk_fn(sk_target, sk_preds, **sk_kwargs) + if metric_args and "squared" in metric_args and not metric_args["squared"]: + res = math.sqrt(res) + return res @pytest.mark.parametrize( @@ -163,6 +192,42 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): _reference_weighted_mean_abs_percentage_error, {}, ), + ( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "mean", "num_outputs": 1}, + ), + ( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "range", "num_outputs": 1}, + ), + ( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "std", "num_outputs": 1}, + ), + ( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "mean", "num_outputs": NUM_TARGETS}, + ), + ( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "range", "num_outputs": NUM_TARGETS}, + ), + ( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "std", "num_outputs": NUM_TARGETS}, + ), ], ) class TestMeanError(MetricTester): @@ -173,6 +238,8 @@ def test_mean_error_class( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args, ddp ): """Test class implementation of metric.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_class_metric_test( ddp=ddp, preds=preds, @@ -186,6 +253,8 @@ def test_mean_error_functional( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args ): """Test functional implementation of metric.""" + if metric_args and "num_outputs" in 
metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_functional_metric_test( preds=preds, target=target, From 642dd2725a498fa376a0f7d0a29353cec5696a24 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 31 May 2024 15:35:12 +0200 Subject: [PATCH 07/30] Update NRMSE computation with normalization options --- .../functional/regression/nrmse.py | 15 ++++++-- src/torchmetrics/regression/nrmse.py | 36 +++++++++++-------- tests/unittests/regression/test_mean_error.py | 30 +++++++++------- 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index 6bb5f87262e..c7ce025902f 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -23,6 +23,15 @@ def _normalized_root_mean_squared_error_update( preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std"] = "mean" ) -> Tuple[Tensor, int, Tensor]: + """Updates and returns the sum of squared errors and the number of observations for NRMSE computation. + + Args: + preds: Predicted tensor + target: Ground truth tensor + num_outputs: Number of outputs in multioutput setting + normalization: type of normalization to be applied. Choose from "mean", "range", "std" + + """ sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs) target = target.view(-1) if num_outputs == 1 else target @@ -55,8 +64,8 @@ def normalized_root_mean_squared_error( preds: estimated labels target: ground truth labels normalization: type of normalization to be applied. Choose from "mean", "range", "std" which corresponds to - normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the - target. + normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the + target. 
num_outputs: Number of outputs in multioutput setting Return: @@ -72,7 +81,7 @@ def normalized_root_mean_squared_error( >>> normalized_root_mean_squared_error(preds, target, normalization="range") tensor(0.2500) >>> normalized_root_mean_squared_error(preds, target, normalization="std") - tensor(0.5222) + tensor(0.6030) Example (multioutput): >>> import torch diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 1b23e359a57..8f210281e6e 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -105,7 +105,7 @@ class NormalizedRootMeanSquaredError(Metric): >>> preds = tensor([3.0, 5.0, 2.5, 7.0]) >>> nrmse = NormalizedRootMeanSquaredError(normalization="mean") >>> nrmse(preds, target) - tensor(0.4000) + tensor(0.1919) >>> nrmse = NormalizedRootMeanSquaredError(normalization="range") >>> nrmse(preds, target) @@ -122,13 +122,17 @@ class NormalizedRootMeanSquaredError(Metric): """ - is_differentiable = True - higher_is_better = False - full_state_update = False + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = True plot_lower_bound: float = 0.0 sum_squared_error: Tensor total: Tensor + min_val: Tensor + max_val: Tensor + mean_val: Tensor + var_val: Tensor def __init__( self, @@ -151,9 +155,9 @@ def __init__( self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) - self.add_state("max_val", default=-float("Inf") * torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) - self.add_state("std_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets. @@ -163,16 +167,18 @@ def update(self, preds: Tensor, target: Tensor) -> None: """ sum_squared_error, num_obs = _mean_squared_error_update(preds, target, self.num_outputs) self.sum_squared_error += sum_squared_error - self.total += num_obs - - # Update normalization statistics target = target.view(-1) if self.num_outputs == 1 else target + + # Update min and max self.min_val = torch.minimum(target.min(dim=0).values, self.min_val) self.max_val = torch.maximum(target.max(dim=0).values, self.max_val) - new_mean = torch.mean(target, dim=0) - self.mean_val = (self.mean_val * (self.total - num_obs) + new_mean * num_obs) / self.total - new_std = torch.std(target, correction=0, dim=0) - self.std_val = torch.sqrt((self.std_val**2 * (self.total - num_obs) + new_std**2 * num_obs) / self.total) + + # Update mean and variance + new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs) + self.total += num_obs + new_var = ((target - new_mean) * (target - self.mean_val)).sum(dim=0) + self.mean_val = new_mean + self.var_val += new_var def compute(self) -> Tensor: """Computes NRMSE over state. 
@@ -182,7 +188,7 @@ def compute(self) -> Tensor: """ if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1): denom = _final_aggregation( - self.min_val, self.max_val, self.mean_val, self.std_val, self.total, self.normalization + self.min_val, self.max_val, self.mean_val, self.var_val, self.total, self.normalization ) else: if self.normalization == "mean": @@ -190,7 +196,7 @@ def compute(self) -> Tensor: elif self.normalization == "range": denom = self.max_val - self.min_val else: - denom = self.std_val + denom = torch.sqrt(self.var_val / self.total) return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) def plot( diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index beabfdb7e4c..2d4e52c285c 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -125,20 +125,17 @@ def _reference_normalized_root_mean_squared_error( y_true = y_true.flatten() y_pred = y_pred.flatten() evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) - arg_mapping = { - "mean": 1, - "range": 2, - "std": 4, - } - + arg_mapping = {"mean": 1, "range": 2, "std": 4} return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) def _reference_weighted_mean_abs_percentage_error(target, preds): + """Reference implementation of Weighted Mean Absolute Percentage Error (WMAPE) metric.""" return np.sum(np.abs(target - preds)) / np.sum(np.abs(target)) def _single_target_ref_wrapper(preds, target, sk_fn, metric_args): + """Reference implementation of single-target metrics.""" sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -152,6 +149,7 @@ def _single_target_ref_wrapper(preds, target, sk_fn, metric_args): def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): + """Reference implementation of multi-target metrics.""" sk_preds = preds.view(-1, NUM_TARGETS).numpy() sk_target = target.view(-1, NUM_TARGETS).numpy() sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} @@ -192,41 +190,47 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): _reference_weighted_mean_abs_percentage_error, {}, ), - ( + pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, _reference_normalized_root_mean_squared_error, {"normalization": "mean", "num_outputs": 1}, + id="nrmse_singleoutput_mean", ), - ( + pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, _reference_normalized_root_mean_squared_error, {"normalization": "range", "num_outputs": 1}, + id="nrmse_singleoutput_range", ), - ( + pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, _reference_normalized_root_mean_squared_error, {"normalization": "std", "num_outputs": 1}, + id="nrmse_singleoutput_std", ), - ( + pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, _reference_normalized_root_mean_squared_error, {"normalization": "mean", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_mean", ), - ( + pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, _reference_normalized_root_mean_squared_error, {"normalization": "range", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_range", ), - ( + pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, 
_reference_normalized_root_mean_squared_error, {"normalization": "std", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_std", ), ], ) @@ -267,6 +271,8 @@ def test_mean_error_differentiability( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args ): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + if metric_args and "num_outputs" in metric_args and preds.ndim < 3: + pytest.skip("Test only runs for multi-output setting") self.run_differentiability_test( preds=preds, target=target, From 84604ad1cb4a5fb06e88958bcd81671df88935b4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 31 May 2024 15:43:31 +0200 Subject: [PATCH 08/30] try fixing docs --- requirements/_docs.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/requirements/_docs.txt b/requirements/_docs.txt index 05f27bee3c1..36eadfa04a2 100644 --- a/requirements/_docs.txt +++ b/requirements/_docs.txt @@ -23,8 +23,6 @@ pydantic > 1.0.0, < 3.0.0 -r image.txt -r multimodal.txt -r text.txt --r text_test.txt --r regression_test.txt # Gallery extra requirements # -------------------------- From 0271d913419e74da60c6d7f3aaa15851193e99d7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 31 May 2024 16:02:54 +0200 Subject: [PATCH 09/30] fix mypy --- src/torchmetrics/functional/regression/nrmse.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index c7ce025902f..5633329cfa3 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Tuple +from typing import Tuple, Union import torch from torch import Tensor @@ -46,7 +46,9 @@ def _normalized_root_mean_squared_error_update( return sum_squared_error, num_obs, denom -def _normalized_root_mean_squared_error_compute(sum_squared_error: Tensor, num_obs: int, denom: Tensor) -> Tensor: +def _normalized_root_mean_squared_error_compute( + sum_squared_error: Tensor, num_obs: Union[int, Tensor], denom: Tensor +) -> Tensor: """Calculates RMSE and normalizes it.""" rmse = torch.sqrt(sum_squared_error / num_obs) return rmse / denom From d279774717923f9c38a75816e0f1d0200c9dc251 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 31 May 2024 16:03:16 +0200 Subject: [PATCH 10/30] fix naming of file --- requirements/{reqression_test.txt => regression_test.txt} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename requirements/{reqression_test.txt => regression_test.txt} (100%) diff --git a/requirements/reqression_test.txt b/requirements/regression_test.txt similarity index 100% rename from requirements/reqression_test.txt rename to requirements/regression_test.txt From 42ffbeaf165c68117d9b375a722eb2c837c3cbe3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 5 Jun 2024 10:02:26 +0200 Subject: [PATCH 11/30] Apply suggestions from code review --- src/torchmetrics/regression/nrmse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 8f210281e6e..7ba065c1c59 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -108,6 +108,7 @@ class NormalizedRootMeanSquaredError(Metric): tensor(0.1919) >>> nrmse = NormalizedRootMeanSquaredError(normalization="range") >>> nrmse(preds, target) + tensor(0.1701) Example:: Multioutput normalized root mean squared error computation: From 069354531ac911ab9e87e58e890d95552db58f27 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 11:25:51 +0200 Subject: [PATCH 12/30] add l2 option --- .../functional/regression/nrmse.py | 20 ++++++++++++------- src/torchmetrics/regression/nrmse.py | 15 +++++++------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index 5633329cfa3..29ceac2b169 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -21,7 +21,7 @@ def _normalized_root_mean_squared_error_update( - preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std"] = "mean" + preds: Tensor, target: Tensor, num_outputs: int, normalization: Literal["mean", "range", "std", "l2"] = "mean" ) -> Tuple[Tensor, int, Tensor]: """Updates and returns the sum of squared errors and the number of observations for NRMSE computation. @@ -29,7 +29,7 @@ def _normalized_root_mean_squared_error_update( preds: Predicted tensor target: Ground truth tensor num_outputs: Number of outputs in multioutput setting - normalization: type of normalization to be applied. Choose from "mean", "range", "std" + normalization: type of normalization to be applied. 
Choose from "mean", "range", "std", "l2" """ sum_squared_error, num_obs = _mean_squared_error_update(preds, target, num_outputs) @@ -41,8 +41,12 @@ def _normalized_root_mean_squared_error_update( denom = torch.max(target, dim=0).values - torch.min(target, dim=0).values elif normalization == "std": denom = torch.std(target, correction=0, dim=0) + elif normalization == "l2": + denom = torch.norm(target, p=2, dim=0) else: - raise ValueError(f"Argument `normalization` should be either 'mean', 'range' or 'std', but got {normalization}") + raise ValueError( + f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2' but got {normalization}" + ) return sum_squared_error, num_obs, denom @@ -57,7 +61,7 @@ def _normalized_root_mean_squared_error_compute( def normalized_root_mean_squared_error( preds: Tensor, target: Tensor, - normalization: Literal["mean", "range", "std"] = "mean", + normalization: Literal["mean", "range", "std", "l2"] = "mean", num_outputs: int = 1, ) -> Tensor: """Calculates the `Normalized Root Mean Squared Error`_ (NRMSE) also know as scatter index. @@ -65,9 +69,9 @@ def normalized_root_mean_squared_error( Args: preds: estimated labels target: ground truth labels - normalization: type of normalization to be applied. Choose from "mean", "range", "std" which corresponds to - normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the - target. + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds + to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the + target or the L2 norm of the target. num_outputs: Number of outputs in multioutput setting Return: @@ -84,6 +88,8 @@ def normalized_root_mean_squared_error( tensor(0.2500) >>> normalized_root_mean_squared_error(preds, target, normalization="std") tensor(0.6030) + >>> normalized_root_mean_squared_error(preds, target, normalization="l2") + tensor(0.5000) Example (multioutput): >>> import torch diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 7ba065c1c59..ded0daf1b27 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -78,7 +78,8 @@ class NormalizedRootMeanSquaredError(Metric): \text{NRMSE} = \frac{\text{RMSE}}{\text{denom}} where RMSE is the root mean squared error and `denom` is the normalization factor. The normalization factor can be - either be the mean, range or standard deviation of the target, which can be set using the `normalization` argument. + either be the mean, range, standard deviation or L2 norm of the target, which can be set using the `normalization` + argument. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -90,9 +91,9 @@ class NormalizedRootMeanSquaredError(Metric): - ``nrmse`` (:class:`~torch.Tensor`): A tensor with the mean squared error Args: - normalization: type of normalization to be applied. Choose from "mean", "range", "std" which corresponds to - normalizing the RMSE by the mean of the target, the range of the target or the standard deviation of the - target. + normalization: type of normalization to be applied. Choose from "mean", "range", "std", "l2" which corresponds + to normalizing the RMSE by the mean of the target, the range of the target, the standard deviation of the + target or the L2 norm of the target. 
num_outputs: Number of outputs in multioutput setting kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -137,15 +138,15 @@ class NormalizedRootMeanSquaredError(Metric): def __init__( self, - normalization: Literal["mean", "range", "std"] = "mean", + normalization: Literal["mean", "range", "std", "l2"] = "mean", num_outputs: int = 1, **kwargs: Any, ) -> None: super().__init__(**kwargs) - if normalization not in ("mean", "range", "std"): + if normalization not in ("mean", "range", "std", "l2"): raise ValueError( - f"Argument `normalization` should be either 'mean', 'range' or 'std', but got {normalization}" + f"Argument `normalization` should be either 'mean', 'range', 'std' or 'l2', but got {normalization}" ) self.normalization = normalization From 52672454f4f83c49542e4a1fef43700bdaee7736 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 11:29:22 +0200 Subject: [PATCH 13/30] added tests for argument error validation --- tests/unittests/regression/test_mean_error.py | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 3e8224ac7ae..d738f7e3124 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -163,14 +163,14 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): @pytest.mark.parametrize( - "preds, target, ref_metric", + ("preds", "target", "ref_metric"), [ (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_wrapper), (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_wrapper), ], ) @pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn, metric_args", + ("metric_class", "metric_functional", "sk_fn", "metric_args"), [ (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), @@ -212,6 +212,13 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): {"normalization": "std", "num_outputs": 1}, id="nrmse_singleoutput_std", ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "l2", "num_outputs": 1}, + id="nrmse_multioutput_l2", + ), pytest.param( NormalizedRootMeanSquaredError, normalized_root_mean_squared_error, @@ -233,6 +240,13 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): {"normalization": "std", "num_outputs": NUM_TARGETS}, id="nrmse_multioutput_std", ), + pytest.param( + NormalizedRootMeanSquaredError, + normalized_root_mean_squared_error, + _reference_normalized_root_mean_squared_error, + {"normalization": "l2", "num_outputs": NUM_TARGETS}, + id="nrmse_multioutput_l2", + ), ], ) class TestMeanError(MetricTester): @@ -309,10 +323,30 @@ def test_mean_error_half_gpu(self, preds, target, ref_metric, metric_class, metr @pytest.mark.parametrize( - "metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, MeanAbsolutePercentageError] + "metric_class", + [ + MeanSquaredError, + MeanAbsoluteError, + MeanSquaredLogError, + MeanAbsolutePercentageError, + NormalizedRootMeanSquaredError, + ], ) def test_error_on_different_shape(metric_class): """Test that error is raised on different shapes of input.""" metric = metric_class() with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same 
shape"): metric(torch.randn(100), torch.randn(50)) + + +@pytest.mark.parametrize( + ("metric_class", "arguments", "error_msg"), + [ + (MeanSquaredError, {"squared": "something"}, "Expected argument `squared` to be a boolean.*"), + (NormalizedRootMeanSquaredError, {"normalization": "something"}, "Argument `normalization` should be either.*"), + ], +) +def test_error_on_wrong_extra_args(metric_class, arguments, error_msg): + """Test that error is raised on wrong extra arguments.""" + with pytest.raises(ValueError, match=error_msg): + metric_class(**arguments) From b967aa0c69ae70d664876216b759c429f417adeb Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 11:32:01 +0200 Subject: [PATCH 14/30] fix doctest --- src/torchmetrics/functional/regression/nrmse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/regression/nrmse.py b/src/torchmetrics/functional/regression/nrmse.py index 29ceac2b169..52cae36adb0 100644 --- a/src/torchmetrics/functional/regression/nrmse.py +++ b/src/torchmetrics/functional/regression/nrmse.py @@ -89,7 +89,7 @@ def normalized_root_mean_squared_error( >>> normalized_root_mean_squared_error(preds, target, normalization="std") tensor(0.6030) >>> normalized_root_mean_squared_error(preds, target, normalization="l2") - tensor(0.5000) + tensor(0.1667) Example (multioutput): >>> import torch From 4d57398b95affad01be666d4eefc138c2f3e7da7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 11:33:43 +0200 Subject: [PATCH 15/30] fix plotting code + test --- src/torchmetrics/regression/nrmse.py | 8 ++++---- tests/unittests/utilities/test_plot.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index ded0daf1b27..b97649b700e 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -223,8 +223,8 @@ def plot( >>> from torch import randn >>> # Example plotting a single value - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot() @@ -233,8 +233,8 @@ def plot( >>> from torch import randn >>> # Example plotting multiple values - >>> from torchmetrics.regression import MeanSquaredError - >>> metric = MeanSquaredError() + >>> from torchmetrics.regression import NormalizedRootMeanSquaredError + >>> metric = NormalizedRootMeanSquaredError() >>> values = [] >>> for _ in range(10): ... 
values.append(metric(randn(10,), randn(10,))) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 465ed2d55e5..add2b78b1b8 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -130,6 +130,7 @@ MeanSquaredError, MeanSquaredLogError, MinkowskiDistance, + NormalizedRootMeanSquaredError, PearsonCorrCoef, R2Score, RelativeSquaredError, @@ -476,6 +477,7 @@ pytest.param(MeanAbsoluteError, _rand_input, _rand_input, id="mean absolute error"), pytest.param(MeanAbsolutePercentageError, _rand_input, _rand_input, id="mean absolute percentage error"), pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), + pytest.param(NormalizedRootMeanSquaredError, _rand_input, _rand_input, id="normalized root mean squared error"), pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), pytest.param(RelativeSquaredError, _rand_input, _rand_input, id="relative squared error"), From 2d26828737f960eaf03885c29cbea2524e9b0554 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 11:49:09 +0200 Subject: [PATCH 16/30] fix part of tests --- src/torchmetrics/regression/nrmse.py | 9 +++- tests/unittests/regression/test_mean_error.py | 53 ++++++++++++++----- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index b97649b700e..ddd232195f0 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -133,6 +133,7 @@ class NormalizedRootMeanSquaredError(Metric): total: Tensor min_val: Tensor max_val: Tensor + target_squared: Tensor mean_val: Tensor var_val: Tensor @@ -160,6 +161,7 @@ def __init__( self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) self.add_state("var_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) + self.add_state("target_squared", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets. 
@@ -171,9 +173,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.sum_squared_error += sum_squared_error target = target.view(-1) if self.num_outputs == 1 else target - # Update min and max + # Update min and max and target squared self.min_val = torch.minimum(target.min(dim=0).values, self.min_val) self.max_val = torch.maximum(target.max(dim=0).values, self.max_val) + self.target_squared += (target**2).sum(dim=0) # Update mean and variance new_mean = (self.total * self.mean_val + target.sum(dim=0)) / (self.total + num_obs) @@ -197,8 +200,10 @@ def compute(self) -> Tensor: denom = self.mean_val elif self.normalization == "range": denom = self.max_val - self.min_val - else: + elif self.normalization == "std": denom = torch.sqrt(self.var_val / self.total) + else: + denom = torch.sqrt(self.target_squared) return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) def plot( diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index d738f7e3124..e0df6fc84ec 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -124,9 +124,12 @@ def _reference_normalized_root_mean_squared_error( if num_outputs == 1: y_true = y_true.flatten() y_pred = y_pred.flatten() - evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) - arg_mapping = {"mean": 1, "range": 2, "std": 4} - return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + if normalization != "l2": + evaluator = RegressionMetric(y_true, y_pred) if normalization == "range" else RegressionMetric(y_pred, y_true) + arg_mapping = {"mean": 1, "range": 2, "std": 4} + return evaluator.normalized_root_mean_square_error(model=arg_mapping[normalization]) + # for l2 normalization we do not have a reference implementation + return np.sqrt(np.mean(np.square(y_true - y_pred), axis=0)) / np.linalg.norm(y_true, axis=0) def _reference_weighted_mean_abs_percentage_error(target, preds): @@ -172,24 +175,50 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): @pytest.mark.parametrize( ("metric_class", "metric_functional", "sk_fn", "metric_args"), [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), - (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": NUM_TARGETS}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {"num_outputs": NUM_TARGETS}), - (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}), - ( + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}, id="mse_singleoutput" + ), + pytest.param( + MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}, id="rmse_singleoutput" + ), + pytest.param( + MeanSquaredError, + mean_squared_error, + sk_mean_squared_error, + {"squared": True, "num_outputs": NUM_TARGETS}, + id="mse_multioutput", + ), + pytest.param(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}, id="mae_singleoutput"), + pytest.param( + MeanAbsoluteError, + mean_absolute_error, + sk_mean_absolute_error, + {"num_outputs": NUM_TARGETS}, + id="mae_multioutput", + ), + pytest.param( + MeanAbsolutePercentageError, + 
mean_absolute_percentage_error, + sk_mean_abs_percentage_error, + {}, + id="mape_singleoutput", + ), + pytest.param( SymmetricMeanAbsolutePercentageError, symmetric_mean_absolute_percentage_error, _reference_symmetric_mape, {}, + id="symmetric_mean_absolute_percentage_error", ), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), - ( + pytest.param( + MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}, id="mean_squared_log_error" + ), + pytest.param( WeightedMeanAbsolutePercentageError, weighted_mean_absolute_percentage_error, _reference_weighted_mean_abs_percentage_error, {}, + id="weighted_mean_absolute_percentage_error", ), pytest.param( NormalizedRootMeanSquaredError, From 6f7c821f83f65a9ce6c42f6697083cc6ede2af62 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 12:03:02 +0200 Subject: [PATCH 17/30] fix implementation --- src/torchmetrics/regression/nrmse.py | 58 ++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index ddd232195f0..60f3f717bfd 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -33,40 +33,58 @@ def _final_aggregation( min_val: Tensor, max_val: Tensor, mean_val: Tensor, - std_val: Tensor, + var_val: Tensor, + target_squared: Tensor, total: Tensor, - normalization: Literal["mean", "range", "std"] = "mean", + normalization: Literal["mean", "range", "std", "l2"] = "mean", ) -> Tensor: + """In the case of multiple devices we need to aggregate the statistics from the different devices.""" if len(min_val) == 1: if normalization == "mean": return mean_val[0] if normalization == "range": return max_val[0] - min_val[0] if normalization == "std": - return std_val[0] - - min_val_1, max_val_1, mean_val_1, std_val_1, total_1 = min_val[0], max_val[0], mean_val[0], std_val[0], total[0] + return var_val[0] + if normalization == "l2": + return target_squared[0] + + min_val_1, max_val_1, mean_val_1, var_val_1, target_squared_1, total_1 = ( + min_val[0], + max_val[0], + mean_val[0], + var_val[0], + target_squared[0], + total[0], + ) for i in range(1, len(min_val)): - min_val_2, max_val_2, mean_val_2, std_val_2, total_2 = min_val[i], max_val[i], mean_val[i], std_val[i], total[i] + min_val_2, max_val_2, mean_val_2, var_val_2, target_squared_2, total_2 = ( + min_val[i], + max_val[i], + mean_val[i], + var_val[i], + target_squared[i], + total[i], + ) total = total_1 + total_2 mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total - std = torch.sqrt( - ( - std_val_1**2 * (total_1 - 1) - + std_val_2**2 * (total_2 - 1) - + (mean_val_1 - mean) ** 2 * total_1 - + (mean_val_2 - mean) ** 2 * total_2 - ) - / (total - 1) - ) + var = ( + (total_1 - 1) * var_val_1 + + (total_2 - 1) * var_val_2 + + ((mean_val_1 - mean) ** 2) * total_1 + + ((mean_val_2 - mean) ** 2) * total_2 + ) / (total - 1) min_val = torch.min(min_val_1, min_val_2) max_val = torch.max(max_val_1, max_val_2) + target_squared = target_squared_1 + target_squared_2 if normalization == "mean": return mean if normalization == "range": return max_val - min_val - return std + if normalization == "std": + return var + return target_squared class NormalizedRootMeanSquaredError(Metric): @@ -193,7 +211,13 @@ def compute(self) -> Tensor: """ if (self.num_outputs == 1 and self.mean_val.numel() > 1) or (self.num_outputs > 1 and self.mean_val.ndim > 1): denom = _final_aggregation( - self.min_val, self.max_val, 
self.mean_val, self.var_val, self.total, self.normalization + min_val=self.min_val, + max_val=self.max_val, + mean_val=self.mean_val, + var_val=self.var_val, + target_squared=self.target_squared, + total=self.total, + normalization=self.normalization, ) else: if self.normalization == "mean": From 83e2f779c7847050eae57a28387a5be8b05a3b4b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 12:06:16 +0200 Subject: [PATCH 18/30] fix doctest --- src/torchmetrics/regression/nrmse.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 60f3f717bfd..3d5a524df02 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -134,11 +134,11 @@ class NormalizedRootMeanSquaredError(Metric): >>> import torch >>> from torchmetrics import NormalizedRootMeanSquaredError - >>> target = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - >>> preds = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) - >>> nrmse = NormalizedRootMeanSquaredError(num_outputs=3) + >>> preds = torch.tensor([[0., 1], [2, 3], [4, 5], [6, 7]]) + >>> target = torch.tensor([[0., 1], [3, 3], [4, 5], [8, 9]]) + >>> nrmse = NormalizedRootMeanSquaredError(num_outputs=2) >>> nrmse(preds, target) - tensor([1., 1., 1.]) + tensor([0.2981, 0.2222]) """ From 01d19b692d0f25539f52782f10b6bbec6e31bafe Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 14:02:49 +0200 Subject: [PATCH 19/30] skip failing tests --- tests/unittests/regression/test_mean_error.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index e0df6fc84ec..2cebe7f9bb1 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -343,6 +343,10 @@ def test_mean_error_half_cpu(self, preds, target, ref_metric, metric_class, metr # WeightedMeanAbsolutePercentageError half + cpu does not work due to missing support in torch.clamp pytest.xfail("WeightedMeanAbsolutePercentageError metric does not support cpu + half precision") + if metric_class == NormalizedRootMeanSquaredError: + # NormalizedRootMeanSquaredError half + cpu does not work due to missing support in torch.sqrt + pytest.xfail("NormalizedRootMeanSquaredError metric does not support cpu + half precision") + self.run_precision_test_cpu(preds, target, metric_class, metric_functional) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") From b7a116dec703c6aa7fe0399d62c5f7bba6d18e01 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 11 Oct 2024 15:37:36 +0200 Subject: [PATCH 20/30] fix ddp testing --- src/torchmetrics/regression/nrmse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 3d5a524df02..2ec53e448ea 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -174,7 +174,7 @@ def __init__( self.num_outputs = num_outputs self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx=None) self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) 
self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) From 604380198b1fc3b03d15fd60a1ddd60444b00999 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 12 Oct 2024 15:45:41 +0200 Subject: [PATCH 21/30] try fixing ddp issues, cannot reproduce locally --- src/torchmetrics/regression/nrmse.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 2ec53e448ea..2ec210085a0 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -14,7 +14,7 @@ from typing import Any, Optional, Sequence, Union import torch -from torch import Tensor, tensor +from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.regression.nrmse import ( @@ -174,7 +174,7 @@ def __init__( self.num_outputs = num_outputs self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx=None) + self.add_state("total", default=torch.zeros(num_outputs), dist_reduce_fx=None) self.add_state("min_val", default=float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) self.add_state("max_val", default=-float("Inf") * torch.ones(self.num_outputs), dist_reduce_fx=None) self.add_state("mean_val", default=torch.zeros(self.num_outputs), dist_reduce_fx=None) @@ -228,6 +228,7 @@ def compute(self) -> Tensor: denom = torch.sqrt(self.var_val / self.total) else: denom = torch.sqrt(self.target_squared) + print(self.sum_squared_error, self.total, denom) return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) def plot( From 2e6e5e3894301ebcb52be20d1746d7fa4f096378 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 12 Oct 2024 17:17:24 +0200 Subject: [PATCH 22/30] fix doctests --- src/torchmetrics/regression/nrmse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 2ec210085a0..ffe326b978e 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -120,8 +120,8 @@ class NormalizedRootMeanSquaredError(Metric): >>> import torch >>> from torchmetrics import NormalizedRootMeanSquaredError - >>> target = tensor([2.5, 5.0, 4.0, 8.0]) - >>> preds = tensor([3.0, 5.0, 2.5, 7.0]) + >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) >>> nrmse = NormalizedRootMeanSquaredError(normalization="mean") >>> nrmse(preds, target) tensor(0.1919) From 62f240af320a4bcb62b8d1e1880e4a1debc44d08 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 14 Oct 2024 09:38:57 +0200 Subject: [PATCH 23/30] remove print --- src/torchmetrics/regression/nrmse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index ffe326b978e..4af3454097c 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -228,7 +228,6 @@ def compute(self) -> Tensor: denom = torch.sqrt(self.var_val / self.total) else: denom = torch.sqrt(self.target_squared) - print(self.sum_squared_error, self.total, denom) return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) def plot( From da588153eef6a05f38541b08efd3337d2e5b933e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 15 Oct 2024 11:34:47 +0200 Subject: [PATCH 24/30] add debug print temp --- 
tests/unittests/_helpers/testers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index c5a69077f3c..49c6aafe8d9 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -32,6 +32,8 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O """Recursively assert that two results are within a certain tolerance.""" # single output compare if isinstance(tm_result, Tensor): + print("TM", tm_result) + print("REF", ref_result) assert np.allclose( tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result, ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, From ec7f070f274f21dbe2ab5494722124608a67779e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 15 Oct 2024 15:25:07 +0200 Subject: [PATCH 25/30] lower atol --- tests/unittests/_helpers/testers.py | 2 -- tests/unittests/regression/test_mean_error.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 49c6aafe8d9..c5a69077f3c 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -32,8 +32,6 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O """Recursively assert that two results are within a certain tolerance.""" # single output compare if isinstance(tm_result, Tensor): - print("TM", tm_result) - print("REF", ref_result) assert np.allclose( tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result, ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 2cebe7f9bb1..38c86817184 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -281,6 +281,8 @@ def _multi_target_ref_wrapper(preds, target, sk_fn, metric_args): class TestMeanError(MetricTester): """Test class for `MeanError` metric.""" + atol = 1e-5 + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_mean_error_class( self, preds, target, ref_metric, metric_class, metric_functional, sk_fn, metric_args, ddp From 47c294ed605b722a7dc350c713ac4bedd77c3895 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 18 Oct 2024 13:38:46 +0200 Subject: [PATCH 26/30] readd debug print --- tests/unittests/_helpers/testers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index c5a69077f3c..49c6aafe8d9 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -32,6 +32,8 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O """Recursively assert that two results are within a certain tolerance.""" # single output compare if isinstance(tm_result, Tensor): + print("TM", tm_result) + print("REF", ref_result) assert np.allclose( tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result, ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, From c5e7b2615234b7a1e5bc3816c25485a894514b7b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Oct 2024 14:29:35 +0200 Subject: [PATCH 27/30] more debugging --- tests/unittests/_helpers/testers.py | 4 +--- 1 file changed, 1 
insertion(+), 3 deletions(-) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 49c6aafe8d9..ec35aeb160b 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -32,14 +32,12 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O """Recursively assert that two results are within a certain tolerance.""" # single output compare if isinstance(tm_result, Tensor): - print("TM", tm_result) - print("REF", ref_result) assert np.allclose( tm_result.detach().cpu().numpy() if isinstance(tm_result, Tensor) else tm_result, ref_result.detach().cpu().numpy() if isinstance(ref_result, Tensor) else ref_result, atol=atol, equal_nan=True, - ) + ), f"TM: {tm_result}, REF: {ref_result}" # multi output compare elif isinstance(tm_result, Sequence): for pl_res, ref_res in zip(tm_result, ref_result): From 0c6718ae0cfa9a0950ab132f2fff1c9dd9b9f8e0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 19 Oct 2024 09:21:28 +0200 Subject: [PATCH 28/30] improve testing instructions --- tests/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/README.md b/tests/README.md index 7f5cbd4e98a..6fce25567ef 100644 --- a/tests/README.md +++ b/tests/README.md @@ -7,16 +7,16 @@ the following command in the root directory of the project: pip install . -r requirements/_devel.txt ``` -Then for windows users, to execute the tests (unit tests and integration tests) run the following command (will only run non-DDP tests): +Then for Windows users, to execute the tests (unit tests and integration tests) run the following command (will only run non-DDP tests): ```bash pytest tests/ ``` -For linux/Mac users you will need to provide the `-m` argument to indicate if `ddp` tests should also be executed: +For Linux/Mac users you will need to provide the `-m` argument to indicate if `ddp` tests should also be executed: ```bash -pytest -m DDP tests/ # to run only DDP tests +USE_PYTEST_POOL="1" pytest -m DDP tests/ # to run only DDP tests pytest -m "not DDP" tests/ # to run all tests except DDP tests ``` From 67925cb066935a771d428872a5c68113192b81ce Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 19 Oct 2024 09:22:53 +0200 Subject: [PATCH 29/30] fix math --- src/torchmetrics/regression/nrmse.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/regression/nrmse.py b/src/torchmetrics/regression/nrmse.py index 4af3454097c..62562803542 100644 --- a/src/torchmetrics/regression/nrmse.py +++ b/src/torchmetrics/regression/nrmse.py @@ -66,14 +66,18 @@ def _final_aggregation( target_squared[i], total[i], ) + # update total and mean total = total_1 + total_2 mean = (total_1 * mean_val_1 + total_2 * mean_val_2) / total - var = ( - (total_1 - 1) * var_val_1 - + (total_2 - 1) * var_val_2 - + ((mean_val_1 - mean) ** 2) * total_1 - + ((mean_val_2 - mean) ** 2) * total_2 - ) / (total - 1) + + # update variance + _temp = (total_1 + 1) * mean - total_1 * mean_val_1 + var_val_1 += (_temp - mean_val_1) * (_temp - mean) - (_temp - mean) ** 2 + _temp = (total_2 + 1) * mean - total_2 * mean_val_2 + var_val_2 += (_temp - mean_val_2) * (_temp - mean) - (_temp - mean) ** 2 + var = var_val_1 + var_val_2 + + # update min and max and target squared min_val = torch.min(min_val_1, min_val_2) max_val = torch.max(max_val_1, max_val_2) target_squared = target_squared_1 + target_squared_2 @@ -83,8 +87,8 @@ def _final_aggregation( if normalization == "range": return max_val 
- min_val if normalization == "std": - return var - return target_squared + return (var / total).sqrt() + return target_squared.sqrt() class NormalizedRootMeanSquaredError(Metric): @@ -219,6 +223,7 @@ def compute(self) -> Tensor: total=self.total, normalization=self.normalization, ) + total = self.total.squeeze().sum(dim=0) else: if self.normalization == "mean": denom = self.mean_val @@ -228,7 +233,8 @@ def compute(self) -> Tensor: denom = torch.sqrt(self.var_val / self.total) else: denom = torch.sqrt(self.target_squared) - return _normalized_root_mean_squared_error_compute(self.sum_squared_error, self.total, denom) + total = self.total + return _normalized_root_mean_squared_error_compute(self.sum_squared_error, total, denom) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None From d160e931615eb7adab50079d21e13ff51ce51c02 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 21 Oct 2024 20:04:02 +0200 Subject: [PATCH 30/30] link --- docs/source/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 5442f9641a9..81f842e7a12 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -447,6 +447,9 @@ def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001 "https://aclanthology.org/W17-4770", # A wavelet transform method to merge Landsat TM and SPOT panchromatic data "https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013", + # Improved normalization of time-lapse seismic data using normalized root mean square repeatability data ... + # ... to improve automatic production and seismic history matching in the Nelson field + "https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109", # todo: these links seems to be unstable, referring to .devcontainer "https://code.visualstudio.com", "https://code.visualstudio.com/.*",
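
For readers following the patch series, a minimal usage sketch of the new metric, assembled from the doctest examples introduced and fixed in patches 17, 18 and 22 (it assumes a torchmetrics build that already contains `NormalizedRootMeanSquaredError`; the expected outputs are the ones shown in those doctests):

```python
import torch
from torchmetrics import NormalizedRootMeanSquaredError

# Single-output case, normalizing the RMSE by the mean of the target
# (values taken from the doctest updated in patch 22).
target = torch.tensor([2.5, 5.0, 4.0, 8.0])
preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
nrmse = NormalizedRootMeanSquaredError(normalization="mean")
print(nrmse(preds, target))  # tensor(0.1919)

# Multi-output case: one NRMSE value per output column
# (values taken from the doctest added in patch 18).
preds = torch.tensor([[0.0, 1], [2, 3], [4, 5], [6, 7]])
target = torch.tensor([[0.0, 1], [3, 3], [4, 5], [8, 9]])
nrmse = NormalizedRootMeanSquaredError(num_outputs=2)
print(nrmse(preds, target))  # tensor([0.2981, 0.2222])
```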
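Patch 16 also adds a NumPy fallback reference in the tests for the new "l2" normalization, because the external `RegressionMetric` evaluator used for the other modes does not cover it. A small standalone sketch of that reference formula (pure NumPy; the input values below are only illustrative and the printed value is approximate, not taken from the patches):

```python
import numpy as np

def nrmse_l2(preds: np.ndarray, target: np.ndarray) -> np.ndarray:
    # RMSE per output, divided by the L2 norm of the target per output,
    # mirroring the l2 branch of the test reference added in patch 16.
    rmse = np.sqrt(np.mean(np.square(target - preds), axis=0))
    return rmse / np.linalg.norm(target, axis=0)

target = np.array([2.5, 5.0, 4.0, 8.0])
preds = np.array([3.0, 5.0, 2.5, 7.0])
print(nrmse_l2(preds, target))  # roughly 0.0887 for these inputs
```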