BinaryConfusionMatrix does not work with float target #212

Open
rafael-patronilo opened this issue Dec 4, 2024 · 0 comments

@rafael-patronilo

🐛 Describe the bug

`BinaryConfusionMatrix` does not work if `target` is a float tensor, even when all of its values are 0 or 1.

Note: It could be argued that `target` should always be an integer tensor. However, given that #146 was resolved, I assume you will also want to fix this one, or at least raise a clear error for it.
In my case, `target` was float because my CSV was loaded automatically by pandas.
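The two options mentioned in the note (accepting float targets, or rejecting them with a clear message) can be combined. Below is a minimal sketch, assuming a float target is acceptable only when every entry is exactly 0 or 1. `coerce_binary_target` is a hypothetical helper for illustration, not part of torcheval's API.

```python
import torch


def coerce_binary_target(target: torch.Tensor) -> torch.Tensor:
    """Hypothetical check: convert a float 0/1 target to int64, else raise.

    Integer targets are returned unchanged.
    """
    if not target.is_floating_point():
        return target
    # Every entry must be exactly 0.0 or 1.0 (NaN fails both comparisons).
    if not bool(((target == 0) | (target == 1)).all()):
        raise ValueError(
            "BinaryConfusionMatrix: float target must contain only 0 and 1, "
            f"got values {torch.unique(target).tolist()}"
        )
    return target.to(torch.int64)


print(coerce_binary_target(torch.tensor([0.0, 1.0, 1.0])))  # tensor([0, 1, 1])
```

Because NaN compares unequal to both 0 and 1, a column with missing values (which pandas stores as NaN in float columns) would be rejected with an explicit message instead of being silently cast.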

Minimal example:

```python
import torch
from torcheval.metrics import BinaryConfusionMatrix

input = torch.randint(0, 2, (10,)).to(torch.float32)
target = torch.randint(0, 2, (10,))

cm = BinaryConfusionMatrix()
cm.update(input, target) # no error here
print(cm.compute())

cm = BinaryConfusionMatrix()
cm.update(input, target.to(torch.float32)) # error here
print(cm.compute())
```

Error and Traceback

```
Traceback (most recent call last):
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/temp.py", line 12, in <module>
    cm.update(input, target.to(torch.float32)) # error here
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torcheval/metrics/classification/confusion_matrix.py", line 311, in update
    self.confusion_matrix += _binary_confusion_matrix_update(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torcheval/metrics/functional/classification/confusion_matrix.py", line 175, in _binary_confusion_matrix_update
    return _update(input, target, 2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torcheval/metrics/functional/classification/confusion_matrix.py", line 232, in _update

    # Each prediction creates an entry at the position (true, pred)
    sparse_cm = torch.sparse_coo_tensor(coordinates, torch.ones_like(target), cm_shape)
                ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

    return sparse_cm.to_dense()
RuntimeError: indices must be an int64 tensor
```

Interpretation

The `coordinates` tensor (passed as the `indices` argument of `torch.sparse_coo_tensor`) is built with `torch.vstack` from `input` (already thresholded, and therefore an integer tensor) and `target`. Because `target` is float, type promotion makes the stacked tensor float as well, and `sparse_coo_tensor` then rejects it, since indices must be an int64 tensor.
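The promotion and one possible fix can be reproduced in plain eager PyTorch. This is a minimal sketch, not torcheval's code: `binary_confusion_matrix` is a hypothetical name, and it assumes the rule "score >= threshold means positive" with a default threshold of 0.5 and the (true, pred) layout described by the comment in the traceback. It only shows the dtype promotion; the `indices must be an int64 tensor` error itself was raised inside torcheval's TorchScript-compiled `_update`, as the traceback shows. The fix shown is casting `target` to int64 before stacking.

```python
import torch


def binary_confusion_matrix(
    input: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
    """Hypothetical eager re-creation of a sparse-COO binary confusion matrix.

    Entry [i, j] counts samples whose true label is i and predicted label is j.
    """
    # Threshold the scores (assumption): values >= threshold become 1, else 0.
    pred = (input >= threshold).long()
    # Cast target to int64 *before* stacking, so a float target cannot
    # promote the whole coordinates tensor to a floating-point dtype.
    coordinates = torch.vstack([target.long(), pred])
    values = torch.ones_like(pred)  # one int64 count per sample
    sparse_cm = torch.sparse_coo_tensor(coordinates, values, (2, 2))
    # Repeated (true, pred) coordinates are summed when densifying.
    return sparse_cm.to_dense()


# Root cause: stacking an int64 row with a float32 row promotes to float32.
pred_demo = torch.tensor([0, 1, 1, 0])
target_demo = torch.tensor([0.0, 1.0, 0.0, 0.0])
print(torch.vstack([target_demo, pred_demo]).dtype)  # torch.float32

print(binary_confusion_matrix(torch.tensor([0.1, 0.9, 0.7, 0.2]), target_demo))
# tensor([[2, 1],
#         [0, 1]])
```

Casting inside the update keeps integer and float targets on the same code path, so both produce the same int64 matrix.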

Versions

  • torcheval 0.0.7
  • torch 2.5.0