Implement partial auroc metric (#3790)
* Implement partial auroc metric

* Add pycodestyle changes

* Added tests for max_fpr

* changelog

* version for tests

* fix imports

* fix tests

* fix tests

* Added more thresholds in (0,1] to test max_fpr

* Removed deprecated 'reorder' param from auroc

* changelog

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>

* remove old structure

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

* fix test error

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
5 people authored Dec 29, 2020
1 parent ae9956f commit 2094633
Showing 3 changed files with 64 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

- Added `max_fpr` parameter to `auroc` metric for computing partial auroc metric ([#3790](https://github.com/PyTorchLightning/pytorch-lightning/pull/3790))

### Changed


35 changes: 30 additions & 5 deletions pytorch_lightning/metrics/functional/classification.py
@@ -15,6 +15,7 @@
from typing import Callable, Optional, Sequence, Tuple

import torch
from distutils.version import LooseVersion

from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1
@@ -544,6 +545,7 @@ def auroc(
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
max_fpr: float = None,
) -> torch.Tensor:
"""
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
@@ -553,6 +555,8 @@
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class
max_fpr: If not ``None``, calculates standardized partial AUC over the
range [0, max_fpr]. Should be a float between 0 and 1.
Return:
Tensor containing ROCAUC score
@@ -569,11 +573,32 @@
' target tensor contains value different from 0 and 1.'
' Use `multiclass_auroc` for multi class classification.')

@auc_decorator()
def _auroc(pred, target, sample_weight, pos_label):
return _roc(pred, target, sample_weight, pos_label)

return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)
if max_fpr is None or max_fpr == 1:
fpr, tpr, _ = __roc(pred, target, sample_weight, pos_label)
return auc(fpr, tpr)
if not (isinstance(max_fpr, float) and 0 < max_fpr <= 1):
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
raise RuntimeError('`max_fpr` argument requires `torch.bucketize` which'
' is not available below PyTorch version 1.6')

fpr, tpr, _ = __roc(pred, target, sample_weight, pos_label)
max_fpr = torch.tensor(max_fpr, device=fpr.device)
# Add a single point at max_fpr and interpolate its tpr value
stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)
tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
fpr = torch.cat([fpr[:stop], max_fpr.view(1)])

# Compute partial AUC
partial_auc = auc(fpr, tpr)

# McClish correction: standardize result to be 0.5 if non-discriminant
# and 1 if maximal
min_area = 0.5 * max_fpr ** 2
max_area = max_fpr
return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))


def multiclass_auroc(
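For reference, the new `max_fpr` argument truncates the ROC curve at `max_fpr` (interpolating the TPR at that cutoff with `torch.bucketize` and `torch.lerp`), computes the area under the truncated curve, and then applies the McClish correction `0.5 * (1 + (pAUC - max_fpr**2 / 2) / (max_fpr - max_fpr**2 / 2))` so that a non-discriminant classifier scores 0.5 and a perfect one scores 1. Below is a minimal usage sketch, not part of the diff; it assumes PyTorch >= 1.6 and that `auroc` is imported from `pytorch_lightning.metrics.functional.classification` as in this commit.

```python
# Minimal usage sketch for the new `max_fpr` argument (illustrative only).
import torch
from pytorch_lightning.metrics.functional.classification import auroc

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])

full_auc = auroc(pred, target)                  # AUROC over the full FPR range [0, 1]
partial_auc = auroc(pred, target, max_fpr=0.5)  # standardized partial AUROC over [0, 0.5]
print(full_auc.item(), partial_auc.item())
```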
40 changes: 32 additions & 8 deletions tests/metrics/functional/test_classification.py
@@ -2,6 +2,7 @@

import pytest
import torch
from distutils.version import LooseVersion
from sklearn.metrics import (
jaccard_score as sk_jaccard_score,
precision_score as sk_precision,
@@ -197,18 +198,41 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
assert thresh.shape == (exp_shape,)


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
@pytest.mark.parametrize(['pred', 'target', 'max_fpr', 'expected'], [
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], None, 1.),
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], None, 0.),
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.8, 0.5),
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.2, 0.5),
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 0.5, 1.),
])
def test_auroc(pred, target, expected):
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
def test_auroc(pred, target, max_fpr, expected):
if max_fpr is not None and LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
pytest.skip('requires torch v1.6 or higher to test max_fpr argument')

score = auroc(torch.tensor(pred), torch.tensor(target), max_fpr=max_fpr).item()
assert score == expected


@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
reason='requires torch v1.6 or higher to test max_fpr argument')
@pytest.mark.parametrize('max_fpr', [
None, 1, 0.99, 0.9, 0.75, 0.5, 0.25, 0.1, 0.01, 0.001,
])
def test_auroc_with_max_fpr_against_sklearn(max_fpr):
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pred = torch.rand((300,), device=device)
# Supports only binary classification
target = torch.randint(2, (300,), dtype=torch.float64, device=device)
sk_score = sk_roc_auc_score(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
max_fpr=max_fpr)
pl_score = auroc(pred, target, max_fpr=max_fpr)

sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
assert torch.allclose(sk_score, pl_score)


def test_multiclass_auroc():
with pytest.raises(ValueError,
match=r".*probabilities, i.e. they should sum up to 1.0 over classes"):
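The new test above compares the result against scikit-learn's `roc_auc_score` for several `max_fpr` values. The same comparison can be reproduced interactively; a hedged sketch (assumes scikit-learn is installed, PyTorch >= 1.6, and the same import path as in this commit):

```python
# Quick manual cross-check against scikit-learn, mirroring the test above.
import torch
from sklearn.metrics import roc_auc_score
from pytorch_lightning.metrics.functional.classification import auroc

torch.manual_seed(0)
pred = torch.rand(300)
target = torch.randint(2, (300,), dtype=torch.float64)

for max_fpr in (None, 0.5, 0.1):
    pl_score = auroc(pred, target, max_fpr=max_fpr)
    sk_score = roc_auc_score(target.numpy(), pred.numpy(), max_fpr=max_fpr)
    assert torch.allclose(pl_score, torch.tensor(sk_score, dtype=pl_score.dtype))
```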
