diff --git a/tests/classification/__init__.py b/tests/classification/__init__.py
index e69de29bb2d..6acfc5730db 100644
--- a/tests/classification/__init__.py
+++ b/tests/classification/__init__.py
@@ -0,0 +1,13 @@
+from torchmetrics import Metric
+
+
+class MetricWrapper(Metric):
+    def __init__(self, metric):
+        super().__init__()
+        self.metric = metric
+
+    def update(self, *args, **kwargs):
+        self.metric.update(*args, **kwargs)
+
+    def compute(self, *args, **kwargs):
+        return self.metric.compute(*args, **kwargs)
diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py
index df7cfe4b496..635f99957f1 100644
--- a/tests/classification/inputs.py
+++ b/tests/classification/inputs.py
@@ -126,3 +126,13 @@ def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_SIZE):
 _temp[_temp == _class_remove] = _class_replace
 
 _input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone())
+
+
+_negmetric_noneavg = {
+    "pred1": torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]),
+    "target1": torch.tensor([0, 1]),
+    "res1": torch.tensor([0.0, 0.0, float("nan")]),
+    "pred2": torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]),
+    "target2": torch.tensor([0, 2]),
+    "res2": torch.tensor([0.0, 0.0, 0.0]),
+}
diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index 20b857feaec..40c27f20368 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -20,6 +20,7 @@
 from sklearn.metrics import accuracy_score as sk_accuracy
 from torch import tensor
 
+from tests.classification import MetricWrapper
 from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
 from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
@@ -32,6 +33,7 @@
 from tests.classification.inputs import _input_multilabel_multidim as _input_mlmd
 from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
+from tests.classification.inputs import _negmetric_noneavg
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import Accuracy
@@ -438,3 +440,11 @@ def test_negative_ignore_index(preds, target, ignore_index, result):
     # Test functional
     with pytest.raises(ValueError, match="^[The `target` has to be a non-negative tensor.]"):
         acc_score = accuracy(preds, target, num_classes=num_classes, ignore_index=ignore_index)
+
+
+def test_negmetric_noneavg(noneavg=_negmetric_noneavg):
+    acc = MetricWrapper(Accuracy(average="none", num_classes=noneavg["pred1"].shape[1]))
+    result1 = acc(noneavg["pred1"], noneavg["target1"])
+    assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
+    result2 = acc(noneavg["pred2"], noneavg["target2"])
+    assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py
index ebdbfe393f9..1c117d70591 100644
--- a/tests/classification/test_precision_recall.py
+++ b/tests/classification/test_precision_recall.py
@@ -20,6 +20,7 @@
 from sklearn.metrics import precision_score, recall_score
 from torch import Tensor, tensor
 
+from tests.classification import MetricWrapper
 from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob
 from tests.classification.inputs import _input_multiclass as _input_mcls
 from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
@@ -30,6 +31,7 @@
 from tests.classification.inputs import _input_multilabel as _input_mlb
 from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
+from tests.classification.inputs import _negmetric_noneavg
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics import Metric, Precision, Recall
@@ -457,3 +459,12 @@ def test_same_input(metric_class, metric_functional, sk_fn, average):
 
     assert torch.allclose(class_res, torch.tensor(sk_res).float())
     assert torch.allclose(func_res, torch.tensor(sk_res).float())
+
+
+@pytest.mark.parametrize("metric_cls", [Precision, Recall])
+def test_noneavg(metric_cls, noneavg=_negmetric_noneavg):
+    prec = MetricWrapper(metric_cls(average="none", num_classes=noneavg["pred1"].shape[1]))
+    result1 = prec(noneavg["pred1"], noneavg["target1"])
+    assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
+    result2 = prec(noneavg["pred2"], noneavg["target2"])
+    assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
diff --git a/torchmetrics/functional/classification/accuracy.py b/torchmetrics/functional/classification/accuracy.py
index 29580ee4047..9ae4febcee6 100644
--- a/torchmetrics/functional/classification/accuracy.py
+++ b/torchmetrics/functional/classification/accuracy.py
@@ -178,7 +178,7 @@ def _accuracy_compute(
         numerator = tp + tn
         denominator = tp + tn + fp + fn
     else:
-        numerator = tp
+        numerator = tp.clone()
         denominator = tp + fn
 
     if mdmc_average != MDMCAverageMethod.SAMPLEWISE:
diff --git a/torchmetrics/functional/classification/precision_recall.py b/torchmetrics/functional/classification/precision_recall.py
index 3d0d7f0376c..5640477ad74 100644
--- a/torchmetrics/functional/classification/precision_recall.py
+++ b/torchmetrics/functional/classification/precision_recall.py
@@ -49,7 +49,7 @@ def _precision_compute(
        >>> _precision_compute(tp, fp, fn, average='micro', mdmc_average=None)
        tensor(0.2500)
     """
-    numerator = tp
+    numerator = tp.clone()
     denominator = tp + fp
 
     if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
@@ -241,7 +241,7 @@ def _recall_compute(
        >>> _recall_compute(tp, fp, fn, average='micro', mdmc_average=None)
        tensor(0.2500)
     """
-    numerator = tp
+    numerator = tp.clone()
     denominator = tp + fn
 
     if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
diff --git a/torchmetrics/functional/classification/specificity.py b/torchmetrics/functional/classification/specificity.py
index 85d0180162c..3e0097842f2 100644
--- a/torchmetrics/functional/classification/specificity.py
+++ b/torchmetrics/functional/classification/specificity.py
@@ -51,7 +51,7 @@ def _specificity_compute(
        tensor(0.6250)
     """
-    numerator = tn
+    numerator = tn.clone()
    denominator = tn + fp
 
     if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
         # a class is not present if there exists no TPs, no FPs, and no FNs
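
Note on the `.clone()` fix: `numerator = tp` binds a new name to the same tensor as the metric's accumulated `tp` state; it does not copy it. With `average="none"`, the compute path later masks out classes that are not present by writing into `numerator` in place (see the specificity context lines above), so the write went through the alias and corrupted the state for the next update/compute cycle. A minimal sketch of the failure mode in plain PyTorch (illustrative only, not the torchmetrics code itself; tensor values are made up):

    import torch

    tp = torch.tensor([2.0, 0.0, 1.0])  # stands in for the accumulated tp state

    numerator = tp                       # old behaviour: alias, not a copy
    numerator[tp == 0] = float("nan")    # in-place masking writes through the alias
    print(tp)                            # tensor([2., nan, 1.]) -- state corrupted

    tp = torch.tensor([2.0, 0.0, 1.0])
    numerator = tp.clone()               # fixed behaviour: independent copy
    numerator[tp == 0] = float("nan")
    print(tp)                            # tensor([2., 0., 1.]) -- state intact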