BinaryAccuracy() sometimes gives incorrect answers due to non-deterministic sigmoiding #1604
🐛 Bug
torchmetrics.classification.BinaryAccuracy will apply a sigmoid to some inputs but not to others, leading to incorrect behavior.

Details
The current behavior of BinaryAccuracy() is to apply a sigmoid transformation before binarizing if the inputs are outside of [0, 1], i.e.

y_hat = 1(sigmoid(z) >= threshold)  if z is outside [0, 1]
y_hat = 1(z >= threshold)           if z is inside [0, 1]
I assume z being inside [0, 1] is checked for the entire batch (i.e. if one element of the batch is outside [0, 1] then the sigmoid is applied to every element).
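In code, I believe the check looks roughly like the following (a sketch of my understanding, not the actual torchmetrics source; the helper name binarize is made up):

```python
import torch

def binarize(preds: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # Sketch of the assumed behavior: the range check is applied to the whole batch.
    if ((preds < 0) | (preds > 1)).any():
        # At least one value falls outside [0, 1]: treat the entire batch as logits.
        preds = torch.sigmoid(preds)
    # Otherwise all values are assumed to already be probabilities and are left untouched.
    return (preds >= threshold).long()
```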
This will cause silent errors. In particular, if the user inputs logits then they expect the logits to always be sigmoided. However, it is entirely possible for all of the logits to lie in [0, 1] for some batches, in which case the input will not be sigmoided, which causes incorrect thresholding.
To Reproduce
Here is a simple example. Suppose our network outputs logits. The example below should lead to a correct prediction; BinaryAccuracy(), however, thinks it's an incorrect prediction.
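A minimal sketch (the logit value 0.3 is just an illustration; the reported accuracy of 0 follows from the conditional-sigmoid behavior described above):

```python
import torch
from torchmetrics.classification import BinaryAccuracy

# The network emits a logit that happens to lie in [0, 1].
# sigmoid(0.3) ~= 0.57 >= 0.5, so the intended prediction is class 1, matching the target.
logits = torch.tensor([0.3])
target = torch.tensor([1])

metric = BinaryAccuracy()  # default threshold=0.5

# Since 0.3 already lies in [0, 1], no sigmoid is applied and 0.3 < 0.5,
# so the prediction is binarized to 0 and the accuracy comes out as 0.0 instead of 1.0.
print(metric(logits, target))
```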
Suggested Fix
I suggest adding an argument indicating whether or not the input predictions are already sigmoided, so that the inputs are either always sigmoided or never sigmoided.
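For example, the call could look like this (the argument name preds_are_logits is purely hypothetical, chosen only to illustrate the idea; it is not an existing torchmetrics parameter):

```python
from torchmetrics.classification import BinaryAccuracy

# Hypothetical argument, for illustration only -- not part of the current API.
metric = BinaryAccuracy(threshold=0.5, preds_are_logits=True)   # always apply the sigmoid
metric = BinaryAccuracy(threshold=0.5, preds_are_logits=False)  # never apply the sigmoid
```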