Skip to content

Commit 79cb5e2

Browse files
vatch123, Borda, SkafteNicki, pre-commit-ci[bot], and mergify[bot]
authored
Fix metrics in macro average (#303)
* fix weights for nonexisting classes
* fix division by zero
* part fix
* add test case
* Apply suggestions from code review
* fix
* changelog
* please fix
* dist_sync not working
* trying to fix

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: SkafteNicki <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent d8b89e0 commit 79cb5e2

File tree

8 files changed

+104
-7
lines changed

8 files changed

+104
-7
lines changed

CHANGELOG.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
7777

7878
### Fixed
7979

80+
- Fixed bug where classification metrics with `average='macro'` would lead to wrong results if a class was missing ([#303](https://github.com/PyTorchLightning/metrics/pull/303))
81+
82+
8083
- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))
8184

8285

@@ -85,7 +88,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
8588

8689
- Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))
8790

88-
8991
## [0.4.1] - 2021-07-05
9092

9193
### Changed

tests/classification/inputs.py

+7
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,10 @@ def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_S
116116
_input_multilabel_prob_plausible = generate_plausible_inputs_multilabel()
117117

118118
_input_binary_prob_plausible = generate_plausible_inputs_binary()
119+
120+
# randomly remove one class from the input
121+
_temp = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
122+
_class_remove, _class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False)
123+
_temp[_temp == _class_remove] = _class_replace
124+
125+
_input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone())

tests/classification/test_accuracy.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tests.classification.inputs import _input_multiclass as _input_mcls
2424
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
2525
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
26+
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
2627
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
2728
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
2829
from tests.classification.inputs import _input_multilabel as _input_mlb
@@ -31,7 +32,7 @@
3132
from tests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
3233
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
3334
from tests.helpers import seed_all
34-
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
35+
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
3536
from torchmetrics import Accuracy
3637
from torchmetrics.functional import accuracy
3738
from torchmetrics.utilities.checks import _input_format_classification
@@ -342,3 +343,21 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
342343
cl_metric(preds, target)
343344
result_cl = cl_metric.compute()
344345
assert torch.allclose(expected, result_cl, equal_nan=True)
346+
347+
348+
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
349+
def test_same_input(average):
350+
preds = _input_miss_class.preds
351+
target = _input_miss_class.target
352+
preds_flat = torch.cat([p for p in preds], dim=0)
353+
target_flat = torch.cat([t for t in target], dim=0)
354+
355+
mc = Accuracy(num_classes=NUM_CLASSES, average=average)
356+
for i in range(NUM_BATCHES):
357+
mc.update(preds[i], target[i])
358+
class_res = mc.compute()
359+
func_res = accuracy(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
360+
sk_res = sk_accuracy(target_flat, preds_flat)
361+
362+
assert torch.allclose(class_res, torch.tensor(sk_res).float())
363+
assert torch.allclose(func_res, torch.tensor(sk_res).float())

tests/classification/test_f_beta.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
from tests.classification.inputs import _input_multiclass as _input_mcls
2525
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
2626
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
27+
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
2728
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
2829
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
2930
from tests.classification.inputs import _input_multilabel as _input_mlb
3031
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
3132
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
3233
from tests.helpers import seed_all
33-
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
34+
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
3435
from torchmetrics import F1, FBeta, Metric
3536
from torchmetrics.functional import f1, fbeta
3637
from torchmetrics.utilities.checks import _input_format_classification
@@ -55,7 +56,6 @@ def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_
5556
preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass
5657
)
5758
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
58-
5959
sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)
6060

6161
if len(labels) != num_classes and not average:
@@ -425,3 +425,25 @@ def test_top_k(
425425

426426
assert torch.isclose(class_metric.compute(), result)
427427
assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
428+
429+
430+
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
431+
@pytest.mark.parametrize(
432+
"metric_class, metric_functional, sk_fn",
433+
[(partial(FBeta, beta=2.0), partial(fbeta, beta=2.0), partial(fbeta_score, beta=2.0)), (F1, f1, f1_score)],
434+
)
435+
def test_same_input(metric_class, metric_functional, sk_fn, average):
436+
preds = _input_miss_class.preds
437+
target = _input_miss_class.target
438+
preds_flat = torch.cat([p for p in preds], dim=0)
439+
target_flat = torch.cat([t for t in target], dim=0)
440+
441+
mc = metric_class(num_classes=NUM_CLASSES, average=average)
442+
for i in range(NUM_BATCHES):
443+
mc.update(preds[i], target[i])
444+
class_res = mc.compute()
445+
func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
446+
sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0)
447+
448+
assert torch.allclose(class_res, torch.tensor(sk_res).float())
449+
assert torch.allclose(func_res, torch.tensor(sk_res).float())

tests/classification/test_precision_recall.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
from tests.classification.inputs import _input_multiclass as _input_mcls
2525
from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
2626
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
27+
from tests.classification.inputs import _input_multiclass_with_missing_class as _input_miss_class
2728
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
2829
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
2930
from tests.classification.inputs import _input_multilabel as _input_mlb
3031
from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits
3132
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
3233
from tests.helpers import seed_all
33-
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
34+
from tests.helpers.testers import NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester
3435
from torchmetrics import Metric, Precision, Recall
3536
from torchmetrics.functional import precision, precision_recall, recall
3637
from torchmetrics.utilities.checks import _input_format_classification
@@ -209,7 +210,7 @@ def test_no_support(metric_class, metric_fn):
209210
)
210211
class TestPrecisionRecall(MetricTester):
211212
@pytest.mark.parametrize("ddp", [False, True])
212-
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
213+
@pytest.mark.parametrize("dist_sync_on_step", [False])
213214
def test_precision_recall_class(
214215
self,
215216
ddp: bool,
@@ -437,3 +438,24 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
437438
cl_metric(preds, target)
438439
result_cl = cl_metric.compute()
439440
assert torch.allclose(expected, result_cl, equal_nan=True)
441+
442+
443+
@pytest.mark.parametrize("average", ["micro", "macro", "weighted"])
444+
@pytest.mark.parametrize(
445+
"metric_class, metric_functional, sk_fn", [(Precision, precision, precision_score), (Recall, recall, recall_score)]
446+
)
447+
def test_same_input(metric_class, metric_functional, sk_fn, average):
448+
preds = _input_miss_class.preds
449+
target = _input_miss_class.target
450+
preds_flat = torch.cat([p for p in preds], dim=0)
451+
target_flat = torch.cat([t for t in target], dim=0)
452+
453+
mc = metric_class(num_classes=NUM_CLASSES, average=average)
454+
for i in range(NUM_BATCHES):
455+
mc.update(preds[i], target[i])
456+
class_res = mc.compute()
457+
func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average)
458+
sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=1)
459+
460+
assert torch.allclose(class_res, torch.tensor(sk_res).float())
461+
assert torch.allclose(func_res, torch.tensor(sk_res).float())

torchmetrics/functional/classification/accuracy.py

+6
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def _accuracy_compute(
8585
else:
8686
numerator = tp
8787
denominator = tp + fn
88+
89+
if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
90+
cond = tp + fp + fn == 0
91+
numerator = numerator[~cond]
92+
denominator = denominator[~cond]
93+
8894
if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
8995
# a class is not present if there exists no TPs, no FPs, and no FNs
9096
meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()

torchmetrics/functional/classification/f_beta.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,15 @@ def _fbeta_compute(
4646
precision = _safe_divide(tp.float(), tp + fp)
4747
recall = _safe_divide(tp.float(), tp + fn)
4848

49+
if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
50+
cond = tp + fp + fn == 0
51+
precision = precision[~cond]
52+
recall = recall[~cond]
53+
4954
num = (1 + beta ** 2) * precision * recall
5055
denom = beta ** 2 * precision + recall
51-
denom[denom == 0.0] = 1 # avoid division by 0
56+
denom[denom == 0.0] = 1.0 # avoid division by 0
57+
5258
# if classes matter and a given class is not present in both the preds and the target,
5359
# computing the score for this class is meaningless, thus they should be ignored
5460
if average == AvgMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:

torchmetrics/functional/classification/precision_recall.py

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def _precision_compute(
2929
) -> Tensor:
3030
numerator = tp
3131
denominator = tp + fp
32+
33+
if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
34+
cond = tp + fp + fn == 0
35+
numerator = numerator[~cond]
36+
denominator = denominator[~cond]
37+
3238
if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
3339
# a class is not present if there exists no TPs, no FPs, and no FNs
3440
meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu()
@@ -199,11 +205,18 @@ def _recall_compute(
199205
) -> Tensor:
200206
numerator = tp
201207
denominator = tp + fn
208+
209+
if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
210+
cond = tp + fp + fn == 0
211+
numerator = numerator[~cond]
212+
denominator = denominator[~cond]
213+
202214
if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
203215
# a class is not present if there exists no TPs, no FPs, and no FNs
204216
meaningless_indeces = ((tp | fn | fp) == 0).nonzero().cpu()
205217
numerator[meaningless_indeces, ...] = -1
206218
denominator[meaningless_indeces, ...] = -1
219+
207220
return _reduce_stat_scores(
208221
numerator=numerator,
209222
denominator=denominator,

0 commit comments

Comments
 (0)