BinaryAccuracy() sometimes gives incorrect answers due to non-deterministic sigmoiding #1604

Open
idc9 opened this issue Mar 9, 2023 · 2 comments · May be fixed by #1676
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed), v1.1.x
Milestone: v1.3.x


idc9 commented Mar 9, 2023

🐛 Bug

torchmetrics.classification.BinaryAccuracy will apply a sigmoid to some inputs but not others, leading to incorrect behavior.

Details

The current behavior of BinaryAccuracy() is to apply a sigmoid transformation before binarizing, but only if the inputs fall outside [0, 1]. From the docs:

If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

i.e.

y_hat = 1(sigmoid(z) >= threshold)   if z is outside [0, 1]
y_hat = 1(z >= threshold)            if z is inside [0, 1]

I assume the "z inside [0, 1]" check is applied to the entire batch (i.e. if one element of the batch falls outside [0, 1], the sigmoid is applied to every element), as sketched below.
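A minimal sketch of that heuristic, assuming the batch-level check described above (the helper name _maybe_sigmoid is illustrative, not torchmetrics' actual internals):

import torch

def _maybe_sigmoid(preds: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: if ANY element lies outside [0, 1], the whole
    # batch is treated as logits and sigmoided; otherwise the values are
    # passed through unchanged and later thresholded as probabilities.
    if preds.is_floating_point() and ((preds < 0) | (preds > 1)).any():
        return preds.sigmoid()
    return preds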

This will cause silent errors. In particular, a user who passes logits expects them to always be sigmoided. However, it is entirely possible for every logit in some batches to lie in [0, 1], in which case the input is not sigmoided and the thresholding is incorrect.

To Reproduce

Here is a simple example. Suppose our network outputs logits.

from torchmetrics.classification import BinaryAccuracy
from scipy.special import expit # expit = sigmoid
import numpy as np
import torch

This example should lead to a correct prediction:

probability_thresh = 0.5
logits = np.array([0.49])  # network output
target = np.array([1])

# logits of 0.49 give a probability of 0.62, indicating class 1, the correct prediction
expit(logits)                                       # array([0.62010643])
int(expit(logits) >= probability_thresh) == target  # True

BinaryAccuracy(), however, thinks it's an incorrect prediction:

# torchmetrics, however, thinks we have the incorrect prediction because it does NOT sigmoid the logits
ba = BinaryAccuracy(threshold=probability_thresh)
ba.forward(preds=torch.tensor(logits), target=torch.tensor(target))  # tensor(0.)
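To make the non-determinism concrete (a follow-up sketch, not part of the original report): if the batch-level heuristic works as described, adding a single out-of-range logit to the same batch flips the result.

# One out-of-range logit (3.0) makes the whole batch be treated as
# logits, so 0.49 is now sigmoided to 0.62 and scored as correct.
logits2 = np.array([0.49, 3.0])
target2 = np.array([1, 1])
ba2 = BinaryAccuracy(threshold=probability_thresh)
ba2.forward(preds=torch.tensor(logits2), target=torch.tensor(target2))  # tensor(1.)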

Suggested Fix

I suggest adding an argument indicating whether or not the input predictions are already sigmoided, so the inputs are either always sigmoided or never sigmoided.
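One hypothetical shape for such an argument (the preds_are_logits name and the subclass are purely illustrative, not an existing torchmetrics API):

import torch
from torchmetrics.classification import BinaryAccuracy

class ExplicitBinaryAccuracy(BinaryAccuracy):
    # Hypothetical wrapper: the caller declares the input format once, so
    # the sigmoid is either always applied or never applied, independent of
    # the values that happen to land in a given batch.
    def __init__(self, preds_are_logits: bool, **kwargs):
        super().__init__(**kwargs)
        self.preds_are_logits = preds_are_logits

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        if self.preds_are_logits:
            # Sigmoided values land in [0, 1], so the base class will
            # threshold them directly and never re-sigmoid.
            preds = preds.sigmoid()
        super().update(preds, target)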


github-actions bot commented Mar 9, 2023

Hi! Thanks for your contribution, great first issue!


idc9 commented Mar 9, 2023

Looks like a similar thing happens in MultilabelAccuracy (https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html).
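If the same heuristic applies there, a minimal sketch mirroring the reproduction above (my numbers, not from the original comment):

from torchmetrics.classification import MultilabelAccuracy
import torch

# Both logits lie in [0, 1], so they are thresholded as probabilities:
# [0.49, 0.49] >= 0.5 -> [0, 0], giving accuracy 0 against targets [1, 1].
# Had they been treated as logits, both would sigmoid to 0.62 -> [1, 1].
mla = MultilabelAccuracy(num_labels=2, threshold=0.5)
mla(torch.tensor([[0.49, 0.49]]), torch.tensor([[1, 1]]))  # tensor(0.)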

@SkafteNicki SkafteNicki linked a pull request Mar 31, 2023 that will close this issue