Constant-memory implementation of common threshold-varying metrics #625
Comments
Hi! Thanks for your contribution, great first issue!
I'm currently struggling with AUROC's memory footprint, so this would honestly be amazing!
Hi @timesler, this sounds like a great addition to torchmetrics. We are aware that some of our implementations suffer from a potentially huge memory footprint, so it would be great to have metrics that solve this. Please feel free to send a PR :]

One small question: would it make sense to combine the binned and non-binned versions into a single metric, with a parameter to switch between them, e.g.

```python
class AUROC(Metric):
    def __init__(self, reduce_memory: bool = False):
        self.reduce_memory = reduce_memory
        ...

    def update(self, preds, target):
        if self.reduce_memory:
            binned_update(preds, target)
        else:
            current_update(preds, target)
```

instead of having multiple versions of the same metric?
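For illustration, this is how the combined API sketched above could look in use. This is purely hypothetical: `reduce_memory` is only a proposal here, and `loader` stands in for any batch iterator.

```python
# Hypothetical usage of the single-metric API sketched above;
# `reduce_memory` and `loader` are illustrative, not an existing API.
metric = AUROC(reduce_memory=True)   # binned, constant-memory state
for preds, target in loader:
    metric.update(preds, target)
score = metric.compute()
```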
I'm beginning some exploratory work on this feature now, in consultation with @timesler.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
We're still working on this.
This issue will be fixed by the classification refactor: see issue #1001 and PR #1195 for all changes.

Small recap: this issue asks for constant-memory implementations of common metrics, e.g.

```python
import torch
from torchmetrics.functional import binary_roc

preds = torch.rand(20)
target = torch.randint(0, 2, (20,))

binary_roc(preds, target, thresholds=None)  # exact but memory intensive
# (tensor([0.0000, 0.1667, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.6667, 0.6667,
#          0.6667, 0.6667, 0.8333, 0.8333, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
#          1.0000, 1.0000, 1.0000]),
#  tensor([0.0000, 0.0000, 0.0000, 0.0714, 0.1429, 0.2143, 0.2143, 0.2143, 0.2857,
#          0.3571, 0.4286, 0.4286, 0.5000, 0.5714, 0.5714, 0.6429, 0.7143, 0.7857,
#          0.8571, 0.9286, 1.0000]),
#  tensor([1.0000, 0.9995, 0.8895, 0.8621, 0.8426, 0.8204, 0.8044, 0.7560, 0.7169,
#          0.7023, 0.6685, 0.6194, 0.5599, 0.5071, 0.4728, 0.4574, 0.4332, 0.2989,
#          0.2535, 0.2446, 0.2025]))

binary_roc(preds, target, thresholds=10)  # less accurate but constant memory
# (tensor([0.0000, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
#          1.0000]),
#  tensor([0.0000, 0.0000, 0.2143, 0.4286, 0.5000, 0.6429, 0.7143, 0.9286, 1.0000,
#          1.0000]),
#  tensor([1.0000, 0.8889, 0.7778, 0.6667, 0.5556, 0.4444, 0.3333, 0.2222, 0.1111,
#          0.0000]))
```

The issue will be closed when #1195 is merged.
Thanks to the developers for their great work. I (and my work) use this package heavily and will continue to do so :).
🚀 Feature
Related to #128, it would be great to also have constant-memory implementations of the ROC curve and the AUC metric. Given that this has already been implemented for precision-recall, the work is 90% done; it just needs a minor extension.
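To illustrate why the remaining extension is small: once a binned ROC curve is available, the AUC follows from a constant-memory trapezoidal integration over it. A minimal sketch (the helper name is illustrative, not an existing torchmetrics API):

```python
# A minimal sketch: AUC from an already-binned ROC curve via the
# trapezoidal rule. `auc_from_binned_roc` is a hypothetical helper.
import torch

def auc_from_binned_roc(fpr: torch.Tensor, tpr: torch.Tensor) -> torch.Tensor:
    order = torch.argsort(fpr)  # trapezoids need monotone x values
    return torch.trapz(tpr[order], fpr[order])
```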
Motivation
The current implementation of AUROC consumes a lot of memory, since it stores every prediction and target seen during `update` and only reduces them when the metric is computed.
Pitch
I am happy to submit a PR for this (together with help from @norrishd, @ronrest, @BlakeJC94, & @ryanmseer), and would propose to do it by:

1. Rather than extending `BinnedPrecisionRecallCurve` right now, writing a more general base class that implements a "confusion matrix curve", where each of TP, TN, FP, and FN is calculated for a set of defined threshold values (see the sketch after this list). An argument to the parent class could allow child classes to specify a subset of these 4 values to prevent unnecessary calculation. The base class could be called something like `ConfusionMatrixCurve`, `StatScoresCurve`, or `BinnedStatScores` (preferences welcome).
2. Reimplementing `BinnedPrecisionRecallCurve` by inheriting from this base class.
3. Adding a new `BinnedROC` that inherits from the base class.
4. Adding `BinnedAUROC` and `BinnedSpecificityAtFixedSensitivity`.

I've implemented similar functionality in a different metrics package here, but for this implementation I would of course follow the patterns and conventions of torchmetrics as best as I could.
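As referenced in the first step above, here is a minimal sketch of the binned "confusion matrix curve" idea, assuming a fixed threshold grid; the class and argument names (`BinnedStatScores`, `num_thresholds`) are illustrative, not a final API:

```python
# A minimal sketch of the constant-memory "confusion matrix curve" idea.
# `BinnedStatScores` and `num_thresholds` are illustrative names only.
import torch

class BinnedStatScores:
    def __init__(self, num_thresholds: int = 100):
        # Fixed threshold grid: state size is independent of dataset size.
        self.thresholds = torch.linspace(0, 1, num_thresholds)
        self.tp = torch.zeros(num_thresholds)
        self.fp = torch.zeros(num_thresholds)
        self.tn = torch.zeros(num_thresholds)
        self.fn = torch.zeros(num_thresholds)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # Shape (num_thresholds, batch): is each prediction >= each threshold?
        pred_pos = preds.unsqueeze(0) >= self.thresholds.unsqueeze(1)
        target_pos = target.bool().unsqueeze(0)
        self.tp += (pred_pos & target_pos).sum(dim=1)
        self.fp += (pred_pos & ~target_pos).sum(dim=1)
        self.fn += (~pred_pos & target_pos).sum(dim=1)
        self.tn += (~pred_pos & ~target_pos).sum(dim=1)

    def roc(self):
        # TPR/FPR at each fixed threshold; accuracy depends on grid density.
        tpr = self.tp / (self.tp + self.fn).clamp(min=1)
        fpr = self.fp / (self.fp + self.tn).clamp(min=1)
        return fpr, tpr, self.thresholds
```

A derived `BinnedROC` would only need to request TP/FP/TN/FN and call something like `roc()`, which is what makes the base-class approach attractive.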
Alternatives