Classification Multilabel Micro AveragePrecision does not form a compute group with comparable metrics #1084

Closed · tsteffek opened this issue on Jun 12, 2022 · 0 comments · Fixed by #1086

Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

@tsteffek

🐛 Bug

While micro and macro AUROC play well with each other and with macro AveragePrecision, micro AveragePrecision will not be merged into the same compute group.

This is because AveragePrecision flattens its predictions and targets in its update() call, while AUROC flattens only in compute(). As a result the stored state shapes don't align, and the compute group merge fails.
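
A quick way to see the mismatch is to inspect the metric states directly after a single update() call. This is a minimal sketch against torchmetrics 0.9.1; it assumes the internal state lists are named preds/target, as in that release:

import torch
from torchmetrics import AUROC, AveragePrecision

preds  = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

auroc = AUROC(average='micro', num_classes=3)
ap    = AveragePrecision(average='micro', num_classes=3)
auroc.update(preds, target)
ap.update(preds, target)

# AUROC keeps the original (3, 3) shape, while micro AveragePrecision has
# already flattened, so the states can't be matched into one compute group.
print(auroc.preds[0].shape)  # torch.Size([3, 3])
print(ap.preds[0].shape)     # torch.Size([9])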

To Reproduce

Code sample

import torch
from torchmetrics import MetricCollection, AUROC, AveragePrecision

m = MetricCollection([
    MetricCollection(AUROC(average='micro', num_classes=3),
                     AveragePrecision(average='micro', num_classes=3),
                     postfix='_micro'),
    MetricCollection(AUROC(average='macro', num_classes=3),
                     AveragePrecision(average='macro', num_classes=3),
                     postfix='_macro'),
])
# Multi-label inputs
ml_preds  = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

print(m._groups)
# Out:
# {0: ['AUROC_micro'],
#  1: ['AveragePrecision_micro'],
#  2: ['AUROC_macro'],
#  3: ['AveragePrecision_macro']}

m.update(ml_preds, ml_target)

print(m._groups)
# Out:
# {0: ['AUROC_micro', 'AUROC_macro', 'AveragePrecision_macro'],
#  1: ['AveragePrecision_micro']} - maybe `AveragePrecision_micro` has body odor?

Expected behavior

Micro AveragePrecision shouldn't flatten during update() but during compute(), which would allow it to share its state with, e.g., AUROC and other AveragePrecision instances.
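
For illustration, a minimal sketch of that ordering. This is not the actual torchmetrics class; the AP computation at the end is plain torch standing in for the library's internal compute function:

import torch
from typing import List

class MicroAveragePrecisionSketch:
    # Hypothetical: store raw tensors in update(), flatten only in compute().

    def __init__(self) -> None:
        self.preds: List[torch.Tensor] = []
        self.target: List[torch.Tensor] = []

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # Keep the original (N, C) shape so the state stays mergeable with
        # metrics like AUROC that also store unflattened tensors.
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> torch.Tensor:
        # Flatten here instead, once all batches have been collected.
        preds = torch.cat(self.preds).flatten()
        target = torch.cat(self.target).flatten()
        # Micro average precision over the flattened scores.
        order = torch.argsort(preds, descending=True)
        target = target[order]
        tp = torch.cumsum(target, dim=0)
        precision = tp / torch.arange(1, target.numel() + 1)
        return (precision * target).sum() / target.sum()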

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.9.1, pip
  • Python & PyTorch Version (e.g., 1.0): 3.8.12 & 1.11.0
  • Any other relevant information such as OS (e.g., Linux): FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel

Additional context

This especially hurts AveragePrecision and similar metrics, since they store all predictions and targets in their state.
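
For scale, a small sketch (torchmetrics 0.9.1, again assuming the state list is named preds): every update() call appends another tensor, so a metric left out of a compute group keeps a full duplicate copy of every batch seen so far.

import torch
from torchmetrics import AveragePrecision

ap = AveragePrecision(average='micro', num_classes=3)
for _ in range(100):
    ap.update(torch.rand(3, 3), torch.randint(0, 2, (3, 3)))

# One stored tensor per update call: memory grows linearly with the number
# of batches, doubled for every metric that fails to share its state.
print(len(ap.preds))  # 100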
