Fix the use of logits in calibration error (#985)
* fix

* Update CHANGELOG.md

* fix logical operator

Co-authored-by: Jirka Borovec <[email protected]>
SkafteNicki and Borda authored Apr 26, 2022
1 parent e236821 commit c494f35
Showing 3 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -52,7 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `BinnedPrecisionRecallCurve` when `thresholds` argument is not provided ([#968](https://github.com/PyTorchLightning/metrics/pull/968))


--
+- Fixed `CalibrationError` to work on logit input ([#985](https://github.com/PyTorchLightning/metrics/pull/985))


## [0.8.0] - 2022-04-14
12 changes: 10 additions & 2 deletions tests/classification/test_calibration_error.py
@@ -3,8 +3,10 @@

 import numpy as np
 import pytest
+from scipy.special import softmax as _softmax

-from tests.classification.inputs import _input_binary_prob
+from tests.classification.inputs import _input_binary_logits, _input_binary_prob
+from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits
 from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
 from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
 from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
@@ -24,8 +26,12 @@
 def _sk_calibration(preds, target, n_bins, norm, debias=False):
     _, _, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
     sk_preds, sk_target = preds.numpy(), target.numpy()

+    if mode == DataType.BINARY:
+        if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all():
+            sk_preds = 1.0 / (1 + np.exp(-sk_preds))  # sigmoid transform
     if mode == DataType.MULTICLASS:
+        if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all():
+            sk_preds = _softmax(sk_preds, axis=1)
         # binary label is whether or not the predicted class is correct
         sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target)
         sk_preds = np.max(sk_preds, axis=1)
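
For readers who want to trace the multiclass branch of this reference by hand, here is a minimal standard-library sketch of the same steps: a batch-wide range check, a row-wise softmax when the check fails, then top-label confidence and correctness. The names `softmax_row` and `top_label_confidence` are hypothetical and do not appear in the test suite; the real helper operates on NumPy arrays.

```python
import math


def softmax_row(scores):
    """Numerically stable softmax over one row of class scores."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_label_confidence(rows, targets):
    """Return one (confidence, is_correct) pair per sample."""
    # Same heuristic as the diff: convert only if some value lies outside [0, 1].
    if not all(0.0 <= s <= 1.0 for row in rows for s in row):
        rows = [softmax_row(row) for row in rows]
    pairs = []
    for row, target in zip(rows, targets):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax
        pairs.append((row[predicted], predicted == target))
    return pairs


# Hypothetical logits for two samples over three classes.
logit_rows = [[2.0, 0.0, -1.0], [0.0, 3.0, 0.0]]
pairs = top_label_confidence(logit_rows, targets=[0, 2])
```

The first sample is predicted as class 0 and counted as correct; the second is predicted as class 1 against a target of 2 and counted as incorrect.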
@@ -46,7 +52,9 @@ def _sk_calibration(preds, target, n_bins, norm, debias=False):
     "preds, target",
     [
         (_input_binary_prob.preds, _input_binary_prob.target),
+        (_input_binary_logits.preds, _input_binary_logits.target),
         (_input_mcls_prob.preds, _input_mcls_prob.target),
+        (_input_mcls_logits.preds, _input_mcls_logits.target),
         (_input_mdmc_prob.preds, _input_mdmc_prob.target),
     ],
 )
4 changes: 4 additions & 0 deletions torchmetrics/functional/classification/calibration_error.py
@@ -143,8 +143,12 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
     _, _, mode = _input_format_classification(preds, target)

     if mode == DataType.BINARY:
+        if not ((0 <= preds) * (preds <= 1)).all():
+            preds = preds.sigmoid()
         confidences, accuracies = preds, target
     elif mode == DataType.MULTICLASS:
+        if not ((0 <= preds) * (preds <= 1)).all():
+            preds = preds.softmax(dim=1)
         confidences, predictions = preds.max(dim=1)
         accuracies = predictions.eq(target)
     elif mode == DataType.MULTIDIM_MULTICLASS:
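
One property of the range check in this hunk is worth keeping in mind: conversion happens only when at least one value in the batch lies outside [0, 1], so logits that happen to fall entirely inside that interval are treated as probabilities. The following is a small standard-library sketch of the binary branch under that reading of the diff; `sigmoid` and `normalize_binary` are hypothetical names, not torchmetrics functions.

```python
import math


def sigmoid(x):
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def normalize_binary(preds):
    """Apply sigmoid only if some value in the batch falls outside [0, 1]."""
    if not all(0.0 <= p <= 1.0 for p in preds):
        return [sigmoid(p) for p in preds]
    return list(preds)


converted = normalize_binary([-2.0, 0.5, 3.0])  # treated as logits
untouched = normalize_binary([0.0, 0.5, 1.0])   # treated as probabilities
```

The value 0.5 is passed through the sigmoid in the first call but left as is in the second, because the decision is made for the batch as a whole rather than per element.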