The different behavior of RetrievalRecall and RetrievalPrecision makes it difficult to compute standard metrics such as Precision@k or Recall@k for multiclass classification problems.
Would it be possible to have them accept the same shape of input, e.g. inputs of shape (batch_size, num_classes) and targets of shape (batch_size, num_classes)?
Example code below:
To install: pip install --pre torcheval-nightly (torcheval version '0.0.7').
import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall
batch_size = 10
num_classes = 20
# generate random predictions
preds = torch.rand(batch_size, num_classes)
# generate random targets
targets = torch.randint(0, num_classes, (batch_size,))
recall = RetrievalRecall(num_queries=batch_size, k=5)
# first one-hot encode the targets (RetrievalRecall takes no num_classes argument and requires binary targets)
targets_one_hot = F.one_hot(targets.type(torch.long), num_classes)
print(targets_one_hot.shape)  # torch.Size([10, 20])
# indexes assign each flattened score to its query, i.e. the sample (row) it came from
indexes = torch.arange(batch_size).repeat(num_classes, 1).T
recall.update(preds.ravel(), targets_one_hot.ravel(), indexes=indexes.ravel())
recall.compute().mean()  # -> 0.1 in the reported run (random, unseeded inputs, so this varies)
from torcheval.metrics import MulticlassRecall, MulticlassPrecision
recall = MulticlassRecall(num_classes=num_classes)
precision = MulticlassPrecision(num_classes=num_classes)
recall.update(preds, targets)
precision.update(preds, targets)
recall.compute(), precision.compute()  # -> 0.1, 0.1 in the reported run (random, unseeded inputs)
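For single-label multiclass targets, each query (row) has exactly one relevant class, so the retrieval definitions reduce to top-k checks: Recall@k for a row is 1 if the true class is among the k highest scores and 0 otherwise, and Precision@k is that hit divided by k. As far as I can tell, `MulticlassRecall` and `MulticlassPrecision` take the argmax of 2-D scores, so they behave like k=1 and are not directly comparable with `RetrievalRecall(k=5)`. The sketch below computes both @k metrics straight from `(batch_size, num_classes)` scores in plain PyTorch; `multiclass_recall_precision_at_k` is a hypothetical helper, not a torcheval API, and ties are broken by `torch.topk`, which may differ from torcheval.

```python
import torch


def multiclass_recall_precision_at_k(preds, targets, k):
    """Per-sample Recall@k and Precision@k for single-label multiclass data.

    preds:   float scores of shape (batch_size, num_classes)
    targets: integer class indices of shape (batch_size,)
    """
    # Indices of the k highest-scoring classes for every sample: (batch_size, k).
    topk_classes = preds.topk(k, dim=1).indices
    # 1.0 where the true class appears in the top k, else 0.0: (batch_size,).
    hits = (topk_classes == targets.unsqueeze(1)).any(dim=1).float()
    # One relevant class per sample: recall@k = hits / 1, precision@k = hits / k.
    return hits, hits / k
```

With `k=5`, averaging the returned recall over the example's `preds` and `targets` should agree with `recall.compute().mean()` from `RetrievalRecall`, up to tie-breaking.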
Current workaround:
import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall
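class MulticlassRetrievalRecall(RetrievalRecall):
    def __init__(self, batch_size, num_classes, **kwargs):
        # One retrieval query per sample in the batch.
        super().__init__(num_queries=batch_size, **kwargs)
        self.num_classes = num_classes

    def update(self, input, target):
        # One-hot targets: the true class is the single relevant item of each query.
        target_one_hot = F.one_hot(target.type(torch.long), self.num_classes)
        # Row i -> query i. Query ids restart at 0 on every call,
        # so this assumes a single update() per compute().
        indexes = torch.arange(len(input), device=input.device).repeat(self.num_classes, 1).T
        return super().update(input.ravel(), target_one_hot.ravel(), indexes=indexes.ravel())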
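The workaround above assigns query ids 0 to len(input) - 1 on every call, so rows from a later batch would share ids with rows from an earlier one. One possible adjustment, sketched here under assumptions, is to flatten each batch with a running query offset so ids keep increasing across batches; `to_retrieval_inputs` and `query_offset` are hypothetical names, and the approach assumes `num_queries` is set to the total number of samples seen before `compute()` (the original uses `num_queries=batch_size` for the single-batch case).

```python
import torch
from torch.nn import functional as F


def to_retrieval_inputs(preds, targets, query_offset=0):
    """Flatten multiclass scores into the (input, target, indexes) layout.

    preds:        float scores of shape (batch_size, num_classes)
    targets:      integer class indices of shape (batch_size,)
    query_offset: query id given to the first row of this batch (hypothetical)
    """
    batch_size, num_classes = preds.shape
    # The true class is the single relevant item for each query (row).
    target_one_hot = F.one_hot(targets.long(), num_classes)
    # Row i becomes query (query_offset + i), repeated once per class.
    query_ids = torch.arange(batch_size, device=preds.device) + query_offset
    indexes = query_ids.unsqueeze(1).expand(batch_size, num_classes)
    return preds.reshape(-1), target_one_hot.reshape(-1), indexes.reshape(-1)
```

A caller would pass the three tensors to `RetrievalRecall.update(input, target, indexes=indexes)` and then add `batch_size` to its running offset. With `query_offset=0`, the indexes match `torch.arange(batch_size).repeat(num_classes, 1).T.ravel()` from the original example.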
Open to any tips on how best to do this! Thanks for this helpful canonical library :)
Versions
python collect_env.py
Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.6.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.22.2
Libc version: N/A
Python version: 3.11.6 (main, Nov 2 2023, 04:39:43) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Max
Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torcheval==0.0.7
[pip3] torcheval-nightly==2023.12.21
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[conda] numpy 1.24.3 py310hb93e574_0
[conda] numpy-base 1.24.3 py310haf87e8b_0
[conda] torch 2.0.1 pypi_0 pypi