[Metrics] Add multiclass auroc #4236
Conversation
@multiclass_auc_decorator(reorder=False)
def _multiclass_auroc(pred, target, sample_weight, num_classes):
    return multiclass_roc(pred, target, sample_weight, num_classes)

class_aurocs = _multiclass_auroc(pred=pred, target=target,
                                 sample_weight=sample_weight,
                                 num_classes=num_classes)
return torch.mean(class_aurocs)
I've implemented this using the multiclass_auc_decorator, similarly to how it's done for auroc, but I have to say that, not being such an experienced Pythonista, I was scratching my head for a good while trying to figure out what the multiclass_auc_decorator was doing. It might take other people reading the code an unnecessary amount of time too. Would the following be more readable? No decorator is needed either outside or within the multiclass_auroc function. Just my humble opinion, which do you guys prefer? :)
class_rocs = multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes)
class_aurocs = []
for fpr, tpr, _ in class_rocs:
    class_aurocs.append(auc(fpr, tpr, reorder=False))
return torch.mean(torch.stack(class_aurocs))
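For anyone else puzzling over it, a decorator along these lines could look roughly like the sketch below. This is only my reconstruction of the idea, not the actual pytorch_lightning code, and the auc import path is an assumption:

import torch
from functools import wraps
from pytorch_lightning.metrics.functional import auc  # assumed import path; may differ across versions

def multiclass_auc_decorator(reorder: bool = False):
    # Sketch: wraps a function that returns per-class (fpr, tpr, thresholds)
    # curves and turns it into one that returns a tensor of per-class AUCs.
    def wrapper(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            aucs = [auc(fpr, tpr, reorder=reorder) for fpr, tpr, _ in fn(*args, **kwargs)]
            return torch.stack(aucs)
        return wrapped
    return wrapper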
Codecov Report
@@           Coverage Diff            @@
##           master   #4236    +/-   ##
=======================================
+ Coverage      90%     93%     +3%
=======================================
  Files         113     113
  Lines        8232    8191     -41
=======================================
+ Hits         7387    7612    +225
+ Misses        845     579    -266
LGTM
lgtm
Great PR! Some extra tests on matrix size would be great!
"Multiclass AUROC metric expects the target scores to be" | ||
" probabilities, i.e. they should sum up to 1.0 over classes") | ||
|
||
if torch.unique(target).size(0) != pred.size(1): |
Shouldn't it be torch.unique(target).size(0) <= pred.size(1)?

>>> target = torch.tensor([0, 0, 0, 0])
>>> torch.unique(target).size(0)
1
Or we could have a get_num_classes util too.
Well, the metric is undefined when torch.unique(target).size(0) < pred.size(1), which is why there is a strict equals. The way this implementation works (based on sklearn's version) is that it uses a one-vs-rest strategy for computing the AUROC: for n classes it computes n binary AUROCs, where each class in turn is considered the positive class and all the other classes negative, and then it averages those.

E.g., for n=3 and target=[0, 1, 1, 2], for class 0 we binarize the target to make 0 the positive class: [1, 0, 0, 0], and compute the AUC of the ROC of that.

If a label is not present in the target, e.g. n=3 and target=[0, 1, 1, 1], then for the absent class 2 the binarized target would look like [0, 0, 0, 0] (all negative) and the ROC cannot be computed (it would raise an error). Consequently, the whole multiclass AUROC is undefined in that case.

As for get_num_classes, there already is such a util, but it does something different from what we need here. It doesn't look at the dimensions of the predictions, just at the max value in both pred and target, and deduces the number of classes from that (which, now that I think about it, could fail silently, for example when n_cls=5 but target=[0, 1, 2, 3]).
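To make the one-vs-rest strategy concrete, here is a minimal sketch (illustration only, not the PR's implementation; auroc stands for the existing binary AUROC functional, and its import path is an assumption):

import torch
from pytorch_lightning.metrics.functional import auroc  # assumed import path for the binary AUROC

def one_vs_rest_auroc(pred, target, num_classes):
    # pred: (N, num_classes) class probabilities, target: (N,) integer labels
    per_class = []
    for c in range(num_classes):
        binary_target = (target == c).long()  # class c positive, every other class negative
        if binary_target.sum() == 0:
            # the class never appears in target, so its ROC/AUROC is undefined,
            # hence the strict check torch.unique(target).size(0) == pred.size(1)
            raise ValueError(f"Class {c} is missing from the target; multiclass AUROC is undefined.")
        per_class.append(auroc(pred[:, c], binary_target))  # binary AUROC for class c
    return torch.mean(torch.stack(per_class))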
Makes sense!
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_auroc(pred, target)   # doctest: +NORMALIZE_WHITESPACE
tensor(0.6667)
"""
Also, we should check pred.size(0) == target.size(0)
I can add that of course, but this check is not done in any other metric implementation, so if it's done here it should probably be done everywhere. If that's desired, I could add a helper to classification.py:

def check_batch_dims(pred, target):
    if not pred.size(0) == target.size(0):
        raise ValueError(f"Batch size for prediction ({pred.size(0)}) and target ({target.size(0)}) must be equal.")

Would that work? Then this helper could be used in each metric instead of copy-pasting the if clause and the exception.
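For illustration, exercising such a helper might look like this (a hypothetical usage sketch, not code from the PR):

import torch

pred = torch.rand(3, 5)        # batch of 3 predictions over 5 classes
target = torch.tensor([0, 1])  # batch of only 2 targets
check_batch_dims(pred, target)
# raises ValueError: Batch size for prediction (3) and target (2) must be equal.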
As we are slowly unifying the functional and class-based interfaces, we are doing more shape checks, so this will come in a future PR :]
"Multiclass AUROC metric expects the target scores to be" | ||
" probabilities, i.e. they should sum up to 1.0 over classes") | ||
|
||
if torch.unique(target).size(0) != pred.size(1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sense !
CHANGELOG.md (outdated)

@@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added trace functionality to the function to_torchscript ([#4142](https://github.com/PyTorchLightning/pytorch-lightning/pull/4142))

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
@ddrevicky could you move this to the unreleased section?
Should be okay now.
* Add functional multiclass AUROC metric
* Add multiclass_auroc tests
* fixup! Add functional multiclass AUROC metric
* fixup! fixup! Add functional multiclass AUROC metric
* Add multiclass_auroc doc reference
* Update CHANGELOG
* formatting
* Shorter error message regex match in tests
* Set num classes as pytest parameter
* formatting
* Update CHANGELOG

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
(cherry picked from commit 38bb4e2)
What does this PR do?
Implements functional multiclass AUROC.
Fixes #3304
Notes on the code:
Had to pass reorder=False for auc because tests against sklearn kept showing different values, and debugging showed that the difference was actually coming from our auc using torch.argsort, which is unstable. A short Colab notebook documents this. I submitted a separate issue #4237 for that.
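As a toy illustration of why the ordering of points with tied x-values matters when a curve is integrated with the trapezoidal rule (a sketch of the general issue only, not the exact failure documented in the notebook):

import torch

# The same four (x, y) points, with the two points tied at x = 1.0 given in different orders.
x = torch.tensor([0.0, 1.0, 1.0, 3.0])
y_a = torch.tensor([0.0, 0.0, 1.0, 1.0])  # one ordering of the tied points
y_b = torch.tensor([0.0, 1.0, 0.0, 1.0])  # tied points swapped

# The trapezoidal areas differ, so a sort that permutes tied values can change the AUC.
print(torch.trapz(y_a, x))  # tensor(2.)
print(torch.trapz(y_b, x))  # tensor(1.5000)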
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃