
[Metrics] Add multiclass auroc #4236

Merged

Conversation

Contributor

@ddrevicky ddrevicky commented Oct 19, 2020

What does this PR do?

Implements functional multiclass AUROC.

Fixes #3304

Notes on the code:

Had to pass reorder=False to auc because tests against sklearn kept showing different values; debugging showed that the difference was actually coming from our auc using torch.argsort, which is unstable (a short Colab notebook documents this). I submitted a separate issue, #4237, for that.
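To illustrate why the ordering matters (a minimal sketch of the failure mode, not the PR's actual auc code): when a curve's x-values contain ties, as ROC curves often do, an unstable sort can permute the tied points and change the trapezoidal area.

import torch

# Minimal sketch (not the PR's auc): trapezoidal area under (x, y) points.
def trapezoid_auc(x, y):
    dx = x[1:] - x[:-1]
    return torch.sum(dx * (y[1:] + y[:-1]) / 2)

# ROC curves often contain tied x-values (same FPR, different TPR).
fpr = torch.tensor([0.0, 0.2, 0.2, 1.0])
tpr = torch.tensor([0.0, 0.2, 0.8, 1.0])
print(trapezoid_auc(fpr, tpr))  # tensor(0.7400)

# An unstable argsort may swap the tied points, which changes the area:
perm = torch.tensor([0, 2, 1, 3])
print(trapezoid_auc(fpr[perm], tpr[perm]))  # tensor(0.5600)

Passing reorder=False and giving auc the points in their original ROC order sidesteps the instability.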

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together? Otherwise, we ask you to create a separate PR for every change.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@mergify mergify bot requested a review from a team October 19, 2020 16:12
Comment on lines +950 to +957
@multiclass_auc_decorator(reorder=False)
def _multiclass_auroc(pred, target, sample_weight, num_classes):
    return multiclass_roc(pred, target, sample_weight, num_classes)

class_aurocs = _multiclass_auroc(pred=pred, target=target,
                                 sample_weight=sample_weight,
                                 num_classes=num_classes)
return torch.mean(class_aurocs)
Contributor Author

I've implemented this using multiclass_auc_decorator, similarly to how it's done for auroc, but I have to say that, not being such an experienced Pythonist, I was scratching my head for a good while trying to figure out what multiclass_auc_decorator was doing. It might take other people reading the code an unnecessary amount of time as well. Would the following be more readable? No decorator is needed either outside or within the multiclass_auroc function. Just my humble opinion; which do you prefer? :)

class_rocs = multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes)
class_aurocs = []
for fpr, tpr, _ in class_rocs:
    class_aurocs.append(auc(fpr, tpr, reorder=False))
return torch.mean(torch.stack(class_aurocs))
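For context, here is a rough reconstruction of what a decorator like multiclass_auc_decorator plausibly does (my reading of the code, not the verbatim source; auc is assumed to be the library's functional AUC helper, imported elsewhere): it wraps a function that returns one (x, y, ...) curve per class and maps auc over the curves.

from functools import wraps
from typing import Callable

import torch

def multiclass_auc_decorator(reorder: bool = True) -> Callable:
    def wrapper(func_to_decorate: Callable) -> Callable:
        @wraps(func_to_decorate)
        def new_func(*args, **kwargs) -> torch.Tensor:
            # The decorated function returns per-class (x, y, ...) curves;
            # compute the AUC of each curve and stack the results.
            # `auc` here stands in for the library's functional AUC helper.
            results = []
            for class_result in func_to_decorate(*args, **kwargs):
                x, y = class_result[:2]
                results.append(auc(x, y, reorder=reorder))
            return torch.stack(results)
        return new_func
    return wrapper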

codecov bot commented Oct 19, 2020

Codecov Report

Merging #4236 into master will increase coverage by 3%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #4236    +/-   ##
=======================================
+ Coverage      90%     93%    +3%     
=======================================
  Files         113     113            
  Lines        8232    8191    -41     
=======================================
+ Hits         7387    7612   +225     
+ Misses        845     579   -266     

@ddrevicky ddrevicky changed the title Feature/3304 multiclass auroc [Metrics] Add multiclass auroc Oct 19, 2020
@Borda Borda added the feature (Is an improvement or enhancement) and Metrics labels Oct 19, 2020
@Borda Borda added this to the 1.0.x milestone Oct 19, 2020
CHANGELOG.md (outdated; resolved)
tests/metrics/functional/test_classification.py (outdated; resolved)
tests/metrics/functional/test_classification.py (outdated; resolved)
@mergify mergify bot requested a review from a team October 19, 2020 18:56
@Borda Borda requested review from justusschock, ananyahjha93, SkafteNicki and teddykoker and removed request for a team October 19, 2020 18:56
@mergify mergify bot requested a review from a team October 19, 2020 18:57
Contributor

@teddykoker teddykoker left a comment

LGTM

@edenlightning edenlightning modified the milestones: 1.0.3, 1.1 Oct 19, 2020
@mergify mergify bot requested a review from a team October 20, 2020 08:07
Member

@Borda Borda left a comment

lgtm

pytorch_lightning/metrics/functional/classification.py (outdated; resolved)
pytorch_lightning/metrics/functional/classification.py (outdated; resolved)
pytorch_lightning/metrics/functional/classification.py (outdated; resolved)
@Borda Borda added the ready (PRs ready to be merged) label Oct 20, 2020
Contributor

@tchaton tchaton left a comment

Great PR! Some extra tests on matrix size would be great!

"Multiclass AUROC metric expects the target scores to be"
" probabilities, i.e. they should sum up to 1.0 over classes")

if torch.unique(target).size(0) != pred.size(1):
Contributor

Shouldn't it be torch.unique(target).size(0) <= pred.size(1)?

>>> target = torch.tensor([0, 0, 0, 0])
>>> torch.unique(target).size(0)
1

Contributor

Or we could have a get_num_classes util too.

Contributor Author

@ddrevicky ddrevicky Oct 20, 2020

Well, the metric is undefined when torch.unique(target).size(0) < pred.size(1), which is why there is a strict equality check. This implementation (based on sklearn's version) uses a one-vs-rest strategy to compute the AUROC: for n classes it computes n binary AUROCs, treating each class in turn as the positive class and all the other classes as negative, and then averages them.

E.g., for n=3 and target=[0, 1, 1, 2], for class 0 we binarize the target so that 0 is the positive class, giving [1, 0, 0, 0], and compute the AUC of the ROC for that.

If a class label is not present in the target, e.g. n=3 and target=[0, 1, 1, 1], then for the absent class 2 the binarized target would be [0, 0, 0, 0] (all negative) and the ROC cannot be computed (it would raise an error). Consequently, the whole multiclass AUROC is undefined in that case.
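As a minimal sketch of this one-vs-rest averaging (my illustration under simplified assumptions, not the PR's code; the binary AUROC below uses the pairwise Mann-Whitney formulation and ignores ties for brevity):

import torch

def binary_auroc(score, target):
    # Fraction of (positive, negative) pairs where the positive scores higher.
    # With no positives (or no negatives) the mean is over an empty tensor and
    # yields NaN, which is exactly the undefined case described above.
    pos = score[target == 1]
    neg = score[target == 0]
    return (pos.unsqueeze(1) > neg.unsqueeze(0)).float().mean()

def ovr_multiclass_auroc(pred, target):
    # pred: (N, C) class probabilities, target: (N,) integer labels.
    aurocs = []
    for c in range(pred.size(1)):
        binarized = (target == c).long()  # class c positive, all others negative
        aurocs.append(binary_auroc(pred[:, c], binarized))
    return torch.stack(aurocs).mean()

pred = torch.tensor([[0.85, 0.05, 0.10],
                     [0.10, 0.80, 0.10],
                     [0.20, 0.60, 0.20],
                     [0.10, 0.10, 0.80]])
target = torch.tensor([0, 1, 1, 2])
print(ovr_multiclass_auroc(pred, target))  # tensor(1.)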

As for get_num_classes, there already is such a util, but it does something different from what we need here. It doesn't look at the dimensions of the predictions, just at the max value in both, and deduces the number of classes from that (which, now that I think about it, could fail silently, for example when n_cls=5 but target=[0, 1, 2, 3]).
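A tiny hypothetical illustration of that silent failure (not the real get_num_classes, just the deduce-from-max idea):

import torch

target = torch.tensor([0, 1, 2, 3])
deduced = int(target.max()) + 1  # deduces 4 classes, even though n_cls is really 5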

Contributor

Makes sense!

>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_auroc(pred, target) # doctest: +NORMALIZE_WHITESPACE
tensor(0.6667)
"""
Contributor

Also, we should check pred.size(0) == target.size(0)

Contributor Author

I can add that, of course, but this check isn't done in any other metric implementation, so if it's done here it should probably be done everywhere. If that's desired, I could add a helper to classification.py:

def check_batch_dims(pred, target):
    if pred.size(0) != target.size(0):
        raise ValueError(
            f"Batch size for prediction ({pred.size(0)}) and target ({target.size(0)}) must be equal."
        )

Would that work? Then this helper could be used in each metric instead of copy-pasting the if clause and the exception.
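Hypothetical usage, with the helper defined above:

import torch

pred = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
target = torch.tensor([0, 1, 1])   # batch sizes differ: 2 vs. 3
check_batch_dims(pred, target)     # raises ValueError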

Member

As we are slowly unifying the functional and class-based interfaces, we are doing more shape checks, so this will come in a future PR :]

"Multiclass AUROC metric expects the target scores to be"
" probabilities, i.e. they should sum up to 1.0 over classes")

if torch.unique(target).size(0) != pred.size(1):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense !

@rohitgr7 rohitgr7 self-requested a review October 30, 2020 11:38
CHANGELOG.md Outdated
@@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added trace functionality to the function `to_torchscript` ([#4142](https://github.com/PyTorchLightning/pytorch-lightning/pull/4142))

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
Member

@ddrevicky could you move this to the unreleased section?

Contributor Author

@ddrevicky ddrevicky Oct 30, 2020

Should be okay now.

@SkafteNicki SkafteNicki merged commit 38bb4e2 into Lightning-AI:master Oct 30, 2020
SeanNaren pushed a commit that referenced this pull request Nov 10, 2020
* Add functional multiclass AUROC metric

* Add multiclass_auroc tests

* fixup! Add functional multiclass AUROC metric

* fixup! fixup! Add functional multiclass AUROC metric

* Add multiclass_auroc doc reference

* Update CHANGELOG

* formatting

* Shorter error message regex match in tests

* Set num classes as pytest parameter

* formatting

* Update CHANGELOG

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
(cherry picked from commit 38bb4e2)
Labels

feature (Is an improvement or enhancement), ready (PRs ready to be merged)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

MulticlassAUROC: Implement a multi-class version of the AUROC metric
7 participants