
RetrievalRecall and RetrievalPrecision require 1D input, unlike MulticlassRecall and MulticlassPrecision, which accept batched input #188

Open
jaanli opened this issue Dec 21, 2023 · 2 comments


@jaanli

jaanli commented Dec 21, 2023

🐛 Describe the bug

RetrievalRecall and RetrievalPrecision expect a different input format than MulticlassRecall and MulticlassPrecision, which makes it difficult to compute standard metrics such as Precision@k or Recall@k for multiclass classification problems.

Would it be possible to have them accept the same kind of batched input as the multiclass metrics, e.g. inputs of shape `(batch_size, num_classes)` and targets of shape `(batch_size, num_classes)`?
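For reference, when each sample has exactly one true class, both metrics reduce to simple top-k statistics: Recall@k is 1 if the true class is among the k highest-scoring classes (else 0), and Precision@k is the number of relevant classes in the top k divided by k, so at most 1/k per sample. The sketch below is a framework-free illustration under that single-label assumption, operating directly on batched input; the function names are hypothetical and not part of torcheval.

```python
from typing import List, Sequence


def top_k_classes(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest scores (ties broken by lower class index)."""
    return sorted(range(len(scores)), key=lambda c: (-scores[c], c))[:k]


def recall_at_k(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Mean over samples of (relevant classes in top k) / (relevant classes).

    With one true class per sample this is the top-k hit rate.
    """
    hits = [1.0 if t in top_k_classes(row, k) else 0.0 for row, t in zip(scores, targets)]
    return sum(hits) / len(hits)


def precision_at_k(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Mean over samples of (relevant classes in top k) / k."""
    vals = [(1.0 if t in top_k_classes(row, k) else 0.0) / k for row, t in zip(scores, targets)]
    return sum(vals) / len(vals)


# Batched input: (batch_size=3, num_classes=4) scores and (batch_size,) class ids.
scores = [
    [0.1, 0.7, 0.2, 0.0],  # top-2: classes 1, 2
    [0.5, 0.1, 0.3, 0.9],  # top-2: classes 3, 0
    [0.2, 0.3, 0.4, 0.1],  # top-2: classes 2, 1
]
targets = [2, 1, 2]
print(recall_at_k(scores, targets, k=2))     # 0.6666666666666666
print(precision_at_k(scores, targets, k=2))  # 0.3333333333333333
```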

Example code below:

To install: `pip install --pre torcheval-nightly` (using '0.0.7').

```python
import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall


batch_size = 10
num_classes = 20
# generate random predictions
preds = torch.rand(batch_size, num_classes)
# generate random targets
targets = torch.randint(0, num_classes, (batch_size,))

recall = RetrievalRecall(num_queries=batch_size, k=5)

# first make the targets one-hot (RetrievalRecall does not accept a num_classes argument and requires binary targets)
targets_one_hot = F.one_hot(targets.type(torch.long), num_classes)
targets_one_hot.shape  # torch.Size([10, 20]), i.e. (batch_size, num_classes)

# indexes associate each prediction with its query: every score in row i gets query id i
indexes = torch.arange(batch_size).repeat(num_classes, 1).T

recall.update(preds.ravel(), targets_one_hot.ravel(), indexes=indexes.ravel())

recall.compute().mean() # -> 0.1


from torcheval.metrics import MulticlassRecall, MulticlassPrecision

recall = MulticlassRecall(num_classes=num_classes)
precision = MulticlassPrecision(num_classes=num_classes)
recall.update(preds, targets)
precision.update(preds, targets)
recall.compute(), precision.compute() # -> 0.1, 0.1
```
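The `indexes` construction in the example gives every score in row i of the batch the query id i, so after `ravel()` the ids read `[0]*num_classes + [1]*num_classes + ...`; the same flat tensor should also be obtainable with `torch.arange(batch_size).repeat_interleave(num_classes)`. The plain-Python sketch below mirrors that row-major flattening to make the 1D layout explicit; `flatten_batch` is a hypothetical helper, not a torcheval function.

```python
from typing import List, Sequence, Tuple


def flatten_batch(
    scores: Sequence[Sequence[float]], targets: Sequence[int]
) -> Tuple[List[float], List[int], List[int]]:
    """Flatten (batch_size, num_classes) scores and (batch_size,) class ids into
    three 1D sequences: scores, binary relevance labels, and query ids
    (row-major, one query per row)."""
    flat_scores: List[float] = []
    flat_relevance: List[int] = []
    flat_query_ids: List[int] = []
    for query_id, (row, target) in enumerate(zip(scores, targets)):
        for cls, score in enumerate(row):
            flat_scores.append(score)
            flat_relevance.append(1 if cls == target else 0)  # one-hot target
            flat_query_ids.append(query_id)  # same id for every class in the row
    return flat_scores, flat_relevance, flat_query_ids


s, rel, ids = flatten_batch([[0.9, 0.1, 0.4], [0.2, 0.8, 0.3]], [0, 2])
print(rel)  # [1, 0, 0, 0, 0, 1]
print(ids)  # [0, 0, 0, 1, 1, 1]
```

A related caveat, to the best of my understanding of the torcheval docs: MulticlassRecall and MulticlassPrecision convert `(n_sample, n_class)` scores into predicted labels with `torch.argmax`, so the second half of the example measures top-1 behavior rather than Recall@5.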

Current workaround:

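```python
import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall


class MulticlassRetrievalRecall(RetrievalRecall):
    """Wraps RetrievalRecall so it accepts (batch_size, num_classes) scores and
    (batch_size,) integer class targets, like MulticlassRecall."""

    def __init__(self, batch_size, num_classes, **kwargs):
        # one retrieval query per sample in the batch
        super().__init__(num_queries=batch_size, **kwargs)
        self.num_classes = num_classes

    def update(self, input, target):
        # class ids -> binary relevance labels of shape (batch_size, num_classes)
        target_one_hot = F.one_hot(target.type(torch.long), self.num_classes)
        # query id i for every score in row i; ids restart at 0 on every call,
        # so rows from different batches end up sharing a query id
        indexes = torch.arange(len(input)).repeat(self.num_classes, 1).T
        super().update(input.ravel(), target_one_hot.ravel(), indexes=indexes.ravel())
```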
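One caveat with this workaround: `torch.arange(len(input))` restarts the query ids at 0 on every `update()` call, so with more than one batch, row i of every batch is folded into the same query i. A possible fix, assuming the total number of samples is known up front so that `num_queries` can be set to it, is to keep a running offset so each sample gets a globally unique id. The sketch below shows only that id bookkeeping in plain Python; `QueryIdAllocator` is a hypothetical name, not a torcheval API.

```python
from typing import List


class QueryIdAllocator:
    """Hands out globally unique query ids across successive batches."""

    def __init__(self) -> None:
        self.offset = 0  # number of samples (queries) seen so far

    def ids_for_batch(self, batch_size: int, num_classes: int) -> List[int]:
        # Row i of this batch becomes query (offset + i); each id is repeated
        # num_classes times, matching a row-major flatten of the score matrix.
        ids = [self.offset + i for i in range(batch_size) for _ in range(num_classes)]
        self.offset += batch_size
        return ids


alloc = QueryIdAllocator()
print(alloc.ids_for_batch(batch_size=2, num_classes=3))  # [0, 0, 0, 1, 1, 1]
print(alloc.ids_for_batch(batch_size=2, num_classes=3))  # [2, 2, 2, 3, 3, 3]
```

Inside the subclass, the corresponding torch expression would be along the lines of `torch.arange(offset, offset + len(input)).repeat_interleave(self.num_classes)`, with `offset` stored on the instance and advanced after each call.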

Usage:

```python
recall_multi = MulticlassRetrievalRecall(batch_size, num_classes, k=5)
recall_multi.update(preds, targets)
recall_multi.compute().mean() # -> 0.1
```
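Since each single-label sample contributes either a hit or a miss, Recall@k can also be accumulated directly from batched input without building query ids at all. Below is a plain-Python sketch of that alternative with an update/compute interface loosely modeled on the usage above; `TopKRecall` is hypothetical, not a torcheval class, and it assumes exactly one true class per sample. Its result is the mean hit rate over all samples seen, which should agree with `compute().mean()` above for a single batch, up to how ties in the scores are broken.

```python
from typing import Sequence


class TopKRecall:
    """Streaming Recall@k for single-label multiclass data.

    update() takes (batch_size, num_classes) scores and (batch_size,) class ids;
    compute() returns the fraction of all samples seen so far whose true class
    is among their k highest scores.
    """

    def __init__(self, k: int) -> None:
        self.k = k
        self.hits = 0
        self.total = 0

    def update(self, scores: Sequence[Sequence[float]], targets: Sequence[int]) -> "TopKRecall":
        for row, target in zip(scores, targets):
            # Rank classes by descending score (ties broken by lower index).
            top_k = sorted(range(len(row)), key=lambda c: (-row[c], c))[: self.k]
            self.hits += int(target in top_k)
            self.total += 1
        return self

    def compute(self) -> float:
        return self.hits / self.total if self.total else 0.0


metric = TopKRecall(k=1)
metric.update([[0.2, 0.8], [0.6, 0.4]], [1, 1])  # first batch: 1 hit out of 2
metric.update([[0.3, 0.7]], [1])                 # second batch: 1 hit out of 1
print(metric.compute())  # 0.6666666666666666
```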

Open to any tips on how best to do this! Thanks for this helpful canonical library :)

Versions

python collect_env.py

Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.6.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.11.6 (main, Nov  2 2023, 04:39:43) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torcheval==0.0.7
[pip3] torcheval-nightly==2023.12.21
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[conda] numpy                     1.24.3          py310hb93e574_0  
[conda] numpy-base                1.24.3          py310haf87e8b_0  
[conda] torch                     2.0.1                    pypi_0    pypi
@jaanli
Author

jaanli commented Dec 21, 2023

cc @jsseely for visibility

@bobakfb
Contributor

bobakfb commented Jan 4, 2024

@galrotem @JKSenthil Any chance one of you could look into this issue and the other one posted by @jaanli?
