[Metrics] Add multiclass auroc #4236

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added trace functionality to the function `to_torchscript` ([#4142](https://github.com/PyTorchLightning/pytorch-lightning/pull/4142))

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
Borda marked this conversation as resolved.

Member
@ddrevicky could you move this to the unreleased section?

Contributor Author
@ddrevicky ddrevicky Oct 30, 2020

Should be okay now.


### Changed

- Called `on_load_checkpoint` before loading `state_dict` ([#4057](https://github.com/PyTorchLightning/pytorch-lightning/pull/4057))
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
@@ -245,6 +245,13 @@ auroc [func]
:noindex:


multiclass_auroc [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.multiclass_auroc
:noindex:


average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -22,6 +22,7 @@
fbeta_score,
multiclass_precision_recall_curve,
multiclass_roc,
multiclass_auroc,
precision,
precision_recall,
precision_recall_curve,
61 changes: 59 additions & 2 deletions pytorch_lightning/metrics/functional/classification.py
@@ -850,13 +850,14 @@ def new_func(*args, **kwargs) -> torch.Tensor:

def multiclass_auc_decorator(reorder: bool = True) -> Callable:
    def wrapper(func_to_decorate: Callable) -> Callable:
        @wraps(func_to_decorate)
        def new_func(*args, **kwargs) -> torch.Tensor:
            results = []
            for class_result in func_to_decorate(*args, **kwargs):
                x, y = class_result[:2]
                results.append(auc(x, y, reorder=reorder))

-            return torch.cat(results)
+            return torch.stack(results)

        return new_func

@@ -891,7 +892,7 @@ def auroc(
    if any(target > 1):
        raise ValueError('AUROC metric is meant for binary classification, but'
                         ' target tensor contains value different from 0 and 1.'
-                        ' Multiclass is currently not supported.')
+                        ' Use `multiclass_auroc` for multi class classification.')

    @auc_decorator(reorder=True)
    def _auroc(pred, target, sample_weight, pos_label):

@@ -900,6 +901,62 @@ def _auroc(pred, target, sample_weight, pos_label):
    return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)


def multiclass_auroc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass
    prediction scores

    Args:
        pred: estimated probabilities, with shape [N, C]
        target: ground-truth labels, with shape [N,]
        sample_weight: sample weights
        num_classes: number of classes (default: None, computes automatically from data)

    Return:
        Tensor containing ROCAUC score

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> multiclass_auroc(pred, target)  # doctest: +NORMALIZE_WHITESPACE
        tensor(0.6667)
    """
Contributor

Also, we should check pred.size(0) == target.size(0)

Contributor Author

I can add that, of course, but this check isn't done in any other metric implementation, so if it's done here it should probably be done everywhere. If that's desired, I could add a helper to classification.py:

def check_batch_dims(pred, target):
    if not pred.size(0) == target.size(0):
        raise ValueError(f"Batch size for prediction ({pred.size(0)}) and target ({target.size(0)}) must be equal.")

Would that work? Then this helper could be used in each metric instead of copy-pasting the if clause and the exception; a hypothetical call site is sketched below.
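For illustration only, here is a self-contained sketch of how the proposed helper might be reused (the helper is repeated from above for completeness; accuracy_with_check is a made-up example metric, not part of this PR):

import torch


def check_batch_dims(pred: torch.Tensor, target: torch.Tensor) -> None:
    # Shared validation: raise if the batch dimensions of pred and target differ.
    if pred.size(0) != target.size(0):
        raise ValueError(
            f"Batch size for prediction ({pred.size(0)}) and target ({target.size(0)}) must be equal.")


def accuracy_with_check(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical metric: the helper replaces a copy-pasted if/raise in every metric.
    check_batch_dims(pred, target)
    return (pred.argmax(dim=1) == target).float().mean()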

Member

As we are slowly unifying the functional and class-based interfaces, we are adding more shape checks, so this will come in a future PR :]

    if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)):
        raise ValueError(
            "Multiclass AUROC metric expects the target scores to be"
            " probabilities, i.e. they should sum up to 1.0 over classes")

    if torch.unique(target).size(0) != pred.size(1):
Contributor

Shouldn't it be torch.unique(target).size(0) <= pred.size(1)?

>>> target = torch.tensor([0, 0, 0, 0])
>>> torch.unique(target).size(0)
1

Contributor

Or we could have a get_num_classes util too.

Contributor Author
@ddrevicky ddrevicky Oct 20, 2020

Well, the metric is undefined when torch.unique(target).size(0) < pred.size(1); that's why there is a strict equality check. The way this implementation works (based on sklearn's version) is that it uses a one-vs-rest strategy for computing the AUROC: for n classes it computes n binary AUROCs, where each class in turn is treated as the positive class and all the other classes as negative, and then it averages those.

E.g., for n=3 and target=[0, 1, 1, 2], for class 0 we binarize the target so that 0 is the positive class: [1, 0, 0, 0], and compute the ROC AUC of that.

If a class label is not present in the target, e.g. n=3 and target=[0, 1, 1, 1], then for the absent class 2 the binarized target would be [0, 0, 0, 0] (all negative) and the ROC cannot be computed (it would raise an error). Consequently, the whole multiclass AUROC is undefined in that case.

As for get_num_classes, there already is such a util, but it does something different from what we need here. It doesn't look at the dimensions of the predictions, just at the max value in both, and deduces the number of classes from that (which, when I think about it now, could fail silently, for example when n_cls=5 but target=[0, 1, 2, 3]).
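To make the one-vs-rest strategy just described concrete, here is a minimal illustrative sketch (not part of this PR) that averages per-class binary AUROCs using the existing binary auroc helper; it assumes every class actually occurs in the target, as discussed above:

import torch
from pytorch_lightning.metrics.functional.classification import auroc


def ovr_auroc_sketch(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # pred: [N, C] class probabilities, target: [N] integer labels covering all C classes
    num_classes = pred.size(1)
    per_class = []
    for c in range(num_classes):
        binarized = (target == c).long()  # class c is positive, all other classes negative
        per_class.append(auroc(pred[:, c], binarized))
    return torch.stack(per_class).mean()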

Contributor

Makes sense!

        raise ValueError(
            f"Number of classes found in 'target' ({torch.unique(target).size(0)})"
            f" does not equal the number of columns in 'pred' ({pred.size(1)})."
            " Multiclass AUROC is not defined when all of the classes do not"
            " occur in the target labels.")

    if num_classes is not None and num_classes != pred.size(1):
        raise ValueError(
            f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal"
            f" the number of classes passed in 'num_classes' ({num_classes}).")

    @multiclass_auc_decorator(reorder=False)
    def _multiclass_auroc(pred, target, sample_weight, num_classes):
        return multiclass_roc(pred, target, sample_weight, num_classes)

    class_aurocs = _multiclass_auroc(pred=pred, target=target,
                                     sample_weight=sample_weight,
                                     num_classes=num_classes)
    return torch.mean(class_aurocs)
Comment on lines +917 to +924
Contributor Author

I've implemented this using multiclass_auc_decorator, similarly to how it's done for auroc, but I have to say that, not being such an experienced Pythonist, I was scratching my head for a good while trying to figure out what multiclass_auc_decorator was doing. It's possible that other people reading the code might also spend an unnecessary amount of time on it. Would the following be more readable? No decorator is needed, either outside or within the multiclass_auroc function. Just my humble opinion; which do you prefer? :)

class_rocs = multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes)
class_aurocs = []
for fpr, tpr, _ in class_rocs:
    class_aurocs.append(auc(fpr, tpr, reorder=False))
return torch.mean(torch.stack(class_aurocs))



def average_precision(
        pred: torch.Tensor,
        target: torch.Tensor,
42 changes: 42 additions & 0 deletions tests/metrics/functional/test_classification.py
@@ -32,6 +32,7 @@
dice_score,
average_precision,
auroc,
multiclass_auroc,
precision_recall_curve,
roc,
auc,
@@ -346,6 +347,47 @@ def test_auroc(pred, target, expected):
    assert score == expected


def test_multiclass_auroc():
    with pytest.raises(ValueError,
                       match=r".*probabilities, i.e. they should sum up to 1.0 over classes"):
        _ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9],
                                                [1.0, 0]]),
                             target=torch.tensor([0, 1]))

    with pytest.raises(ValueError,
                       match=r".*not defined when all of the classes do not occur in the target.*"):
        _ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1),
                             target=torch.tensor([1, 0, 1, 0]))

    with pytest.raises(ValueError,
                       match=r".*does not equal the number of classes passed in 'num_classes'.*"):
        _ = multiclass_auroc(pred=torch.rand((5, 4)).softmax(dim=1),
                             target=torch.tensor([0, 1, 2, 2, 3]),
                             num_classes=6)


@pytest.mark.parametrize('n_cls', [2, 5, 10, 50])
def test_multiclass_auroc_against_sklearn(n_cls):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    n_samples = 300
    pred = torch.rand(n_samples, n_cls, device=device).softmax(dim=1)
    target = torch.randint(n_cls, (n_samples,), device=device)
    # Make sure target includes all class labels so that multiclass AUROC is defined
    target[10:10 + n_cls] = torch.arange(n_cls)

    pl_score = multiclass_auroc(pred, target)
    # For the binary case, sklearn expects an (n_samples,) array of probabilities of
    # the positive class
    pred = pred[:, 1] if n_cls == 2 else pred
    sk_score = sk_roc_auc_score(target.cpu().detach().numpy(),
                                pred.cpu().detach().numpy(),
                                multi_class="ovr")

    sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
    assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize(['x', 'y', 'expected'], [
    pytest.param([0, 1], [0, 1], 0.5),
    pytest.param([1, 0], [0, 1], 0.5),