
Bug in MulticlassRecall example when adding one additional class #189

Open
jaanli opened this issue Dec 22, 2023 · 1 comment

Comments


jaanli commented Dec 22, 2023

🐛 Describe the bug

The example from the docs raises an error when modified slightly: https://pytorch.org/torcheval/stable/generated/torcheval.metrics.MulticlassRecall.html#torcheval.metrics.MulticlassRecall

>>> import torch
>>> from torcheval.metrics import MulticlassRecall
>>> metric = MulticlassRecall(num_classes=4)
>>> input = torch.tensor([[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]])
>>> target = torch.tensor([0, 1, 2, 3])
>>> metric.update(input, target)
>>> metric.compute()
tensor(0.5000)
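The 0.5 above can be reproduced by hand. Assuming the constructor's default is `average="micro"`, multiclass recall reduces to correct predictions divided by the number of samples. This plain-Python sketch (no torcheval involved) takes the argmax of each row and counts matches:

```python
# Hand-check of the docs example, assuming average="micro" (the default):
# recall = total true positives / total samples.
scores = [
    [0.9, 0.1, 0, 0],
    [0.1, 0.2, 0.4, 0.3],
    [0, 1.0, 0, 0],
    [0, 0, 0.2, 0.8],
]
target = [0, 1, 2, 3]

# Predicted class = index of the highest score in each row (like torch.argmax).
pred = [max(range(len(row)), key=row.__getitem__) for row in scores]

num_tp = sum(p == t for p, t in zip(pred, target))
micro_recall = num_tp / len(target)
print(pred, micro_recall)  # [0, 2, 1, 3] 0.5
```

Only samples 0 and 3 are classified correctly, giving 2 / 4 = 0.5.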

Adding a fifth class (which has no samples in either the input or the target) and specifying average="macro" makes compute() fail:

import torch
from torcheval.metrics import MulticlassRecall

metric = MulticlassRecall(num_classes=5, average="macro")
input = torch.tensor([[0.9, 0.1, 0, 0, 0], [0.1, 0.2, 0.4, 0.3, 0], [0, 1.0, 0, 0, 0], [0, 0, 0.2, 0.8, 0]])
target = torch.tensor([0, 1, 2, 3])
metric.update(input, target)
metric.compute()

Yields:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/me/projects/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/me/projects/.venv/lib/python3.11/site-packages/torcheval/metrics/classification/recall.py", line 243, in compute
    return _recall_compute(
           ^^^^^^^^^^^^^^^^
  File "/Users/me/projects/.venv/lib/python3.11/site-packages/torcheval/metrics/functional/classification/recall.py", line 195, in _recall_compute
    recall = num_tp / num_labels
             ~~~~~~~^~~~~~~~~~~~
RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 0
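For comparison, here is a minimal pure-Python sketch of what macro-averaged recall could return for this input, under the assumption that a class occurring in neither `target` nor the predictions should simply be left out of the average. The function `macro_recall` is hypothetical and is not torcheval's implementation; the handling of a class that is predicted but never a target (scored 0.0 here) is a further assumption.

```python
def macro_recall(scores: list[list[float]], target: list[int], num_classes: int) -> float:
    """Hypothetical macro-averaged recall that skips absent classes.

    Assumptions (not torcheval's documented behaviour):
    - a class with no target samples and no predictions is left out of the average;
    - a class that is predicted but never appears in `target` scores 0.0.
    """
    # Predicted class = index of the highest score in each row (like argmax).
    pred = [max(range(num_classes), key=row.__getitem__) for row in scores]
    per_class = []
    for c in range(num_classes):
        num_labels = sum(t == c for t in target)
        num_predictions = sum(p == c for p in pred)
        if num_labels == 0 and num_predictions == 0:
            continue  # class never occurs: skip it instead of failing
        num_tp = sum(p == t == c for p, t in zip(pred, target))
        per_class.append(num_tp / num_labels if num_labels else 0.0)
    return sum(per_class) / len(per_class)


# The five-class input from the report; class 4 occurs in neither input nor target.
scores = [
    [0.9, 0.1, 0, 0, 0],
    [0.1, 0.2, 0.4, 0.3, 0],
    [0, 1.0, 0, 0, 0],
    [0, 0, 0.2, 0.8, 0],
]
print(macro_recall(scores, [0, 1, 2, 3], num_classes=5))  # 0.5
```

Under that assumption the per-class recalls are 1, 0, 0 and 1 for classes 0 to 3, so the result matches the four-class case instead of raising a shape error.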

Versions

python collect_env.py

Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.6.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.11.6 (main, Nov  2 2023, 04:39:43) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torcheval==0.0.7
[pip3] torcheval-nightly==2023.12.21
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[conda] numpy                     1.24.3          py310hb93e574_0  
[conda] numpy-base                1.24.3          py310haf87e8b_0  
[conda] torch                     2.0.1                    pypi_0    pypi

dburian commented Mar 19, 2024

I ran into this issue as well. If macro-averaged MulticlassRecall doesn't receive at least one prediction for each class, compute() throws the RuntimeError you posted.

Your example with one added sample that predicts (and targets) the final class computes without error:

import torch
from torcheval.metrics import MulticlassRecall

metric = MulticlassRecall(num_classes=5, average="macro")
input = torch.tensor([
    [0.9, 0.1, 0, 0, 0],
    [0.1, 0.2, 0.4, 0.3, 0],
    [0, 1.0, 0, 0, 0],
    [0, 0, 0.2, 0.8, 0],
    [0, 0, 0, 0, 1.0]
])
target = torch.tensor([0, 1, 2, 3, 4])
metric.update(input, target)
metric.compute()

tensor(0.6000)
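The 0.6000 can be checked by hand: the argmax predictions are [0, 2, 1, 3, 4], so classes 0, 3 and 4 are fully recalled while classes 1 and 2 are not. A plain-Python sketch of that per-class arithmetic (no torcheval involved; every class now has exactly one target sample, so no division by zero occurs):

```python
# Per-class recall for the five-sample example; predicted class = argmax of each row.
scores = [
    [0.9, 0.1, 0, 0, 0],
    [0.1, 0.2, 0.4, 0.3, 0],
    [0, 1.0, 0, 0, 0],
    [0, 0, 0.2, 0.8, 0],
    [0, 0, 0, 0, 1.0],
]
target = [0, 1, 2, 3, 4]
pred = [max(range(len(row)), key=row.__getitem__) for row in scores]

recalls = []
for c in range(5):
    num_labels = sum(t == c for t in target)
    num_tp = sum(p == t == c for p, t in zip(pred, target))
    recalls.append(num_tp / num_labels)

macro = sum(recalls) / len(recalls)  # unweighted mean over classes
print(recalls, macro)  # [1.0, 0.0, 0.0, 1.0, 1.0] 0.6
```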

After looking around, I found that this is a duplicate of #150; it should be addressed by the still-open PR #166.
