
Constant-memory implementation of common threshold-varying metrics #625

Closed
timesler opened this issue Nov 16, 2021 · 7 comments · Fixed by #1195
Labels: enhancement, help wanted, New metric
Milestone: v0.10

Comments

@timesler

timesler commented Nov 16, 2021

Thanks to the developers for their great work. I (and my work) use this package heavily and will continue to do so :).

🚀 Feature

Related to #128, it would be great to also have constant-memory implementations of the ROC curve and the AUC metric. Given that this has already been implemented for precision-recall, the work is 90% done; it just needs a minor extension.

Motivation

The current implementation of AUROC consumes a lot of memory.

Pitch

I am happy to submit a PR for this (together with help from @norrishd, @ronrest, @BlakeJC94, & @ryanmseer), and would propose to do it by:

  1. Using logic similar to what is currently in BinnedPrecisionRecallCurve, write a more general base class that implements a "confusion matrix curve", where each of TP, TN, FP, and FN is calculated for a set of defined threshold values. An argument to the parent class could allow child classes to specify a subset of these 4 values to prevent unnecessary calculation. The base class could be called something like ConfusionMatrixCurve, StatScoresCurve, or BinnedStatScores (preferences welcome).
  2. Simplify BinnedPrecisionRecallCurve by inheriting from this base class.
  3. Implement a constant memory version of the ROC called BinnedROC that inherits from the base class.
  4. Similarly, we could also add constant memory implementations of things like BinnedAUROC, and BinnedSpecificityAtFixedSensitivity.
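
To make the proposal concrete, here is a minimal pure-Python sketch of the binned base-class idea from step 1. The class name ConfusionMatrixCurve is one of the candidate names suggested above, and this is an illustration of the technique only, not torchmetrics code; a real implementation would vectorize with tensors:

```python
class ConfusionMatrixCurve:
    """Accumulate TP/FP/TN/FN counts at a fixed set of thresholds.

    Memory use is O(len(thresholds)), independent of how many samples
    are passed to update() -- the constant-memory idea of this issue.
    """

    def __init__(self, thresholds):
        self.thresholds = list(thresholds)
        n = len(self.thresholds)
        self.tp = [0] * n
        self.fp = [0] * n
        self.tn = [0] * n
        self.fn = [0] * n

    def update(self, preds, target):
        # Only the per-threshold counters are updated; the raw
        # predictions are never stored.
        for p, t in zip(preds, target):
            for i, thr in enumerate(self.thresholds):
                pred_pos = p >= thr
                if pred_pos and t == 1:
                    self.tp[i] += 1
                elif pred_pos and t == 0:
                    self.fp[i] += 1
                elif not pred_pos and t == 0:
                    self.tn[i] += 1
                else:
                    self.fn[i] += 1

    def roc(self):
        # TPR = TP / (TP + FN), FPR = FP / (FP + TN) at each threshold.
        tpr = [tp / (tp + fn) if tp + fn else 0.0
               for tp, fn in zip(self.tp, self.fn)]
        fpr = [fp / (fp + tn) if fp + tn else 0.0
               for fp, tn in zip(self.fp, self.tn)]
        return fpr, tpr


curve = ConfusionMatrixCurve([0.25, 0.5, 0.75])
curve.update([0.1, 0.4, 0.6, 0.9], [0, 0, 1, 1])
fpr, tpr = curve.roc()  # fpr [0.5, 0.0, 0.0], tpr [1.0, 1.0, 0.5]
```

A subclass like BinnedROC would then only need to expose roc() as its compute step, and a precision-recall subclass would derive its values from the same counters.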

I've implemented similar functionality in a different metrics package here, but for this implementation I would of course follow the patterns and conventions of torchmetrics as best as I could.

Alternatives

@timesler timesler added the enhancement label Nov 16, 2021
@github-actions

Hi! Thanks for your contribution, great first issue!

@miccio-dk

I'm currently struggling with AUROC's memory footprint, so this would honestly be amazing!

@SkafteNicki (Member)

Hi @timesler,

This sounds like a great addition to torchmetrics. We are aware that some of our implementations suffer from a potentially huge memory footprint, so it would be great to have metrics that solve this.

Please feel free to send a PR :]

One small question: would it make sense to combine the binned and non-binned metrics into a single metric, with a parameter to switch between them? E.g.

class AUROC(Metric):
    def __init__(self, reduce_memory: bool = False):
        super().__init__()
        self.reduce_memory = reduce_memory
        ...

    def update(self, preds, target):
        if self.reduce_memory:
            binned_update(preds, target)
        else:
            current_update(preds, target)

instead of having multiple of the same metric?

@ryandaryl

I'm beginning some exploratory work on this feature now, in consultation with @timesler.

@Borda Borda added the help wanted label Jan 6, 2022
@stale

stale bot commented Mar 19, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Mar 19, 2022
@ryandaryl

We're still working on this.

@stale stale bot removed the wontfix label Mar 20, 2022
@Borda Borda added this to the v0.10 milestone Jul 27, 2022
@SkafteNicki (Member)

The issue will be fixed by the classification refactor: see issue #1001 and PR #1195 for all changes.

Small recap: this issue asks for constant-memory implementations of common metrics like ROC etc. After the refactor, metrics such as roc, auroc, precision_recall, and average_precision all support constant-memory implementations via the thresholds argument. If thresholds=None, the standard approach is used (accurate, but not memory-constant); if thresholds=10 (an int) or thresholds=[0.1, 0.3, 0.5, 0.7] (a list of floats), a binned version is used that is less accurate but memory-constant. Example below for roc:

from torchmetrics.functional import binary_roc
import torch

preds = torch.rand(20)
target = torch.randint(0, 2, (20,))

binary_roc(preds, target, thresholds=None)  # accurate but memory intensive
# (tensor([0.0000, 0.1667, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.6667, 0.6667,
#         0.6667, 0.6667, 0.8333, 0.8333, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
#         1.0000, 1.0000, 1.0000]),
# tensor([0.0000, 0.0000, 0.0000, 0.0714, 0.1429, 0.2143, 0.2143, 0.2143, 0.2857,
#         0.3571, 0.4286, 0.4286, 0.5000, 0.5714, 0.5714, 0.6429, 0.7143, 0.7857,
#         0.8571, 0.9286, 1.0000]),
# tensor([1.0000, 0.9995, 0.8895, 0.8621, 0.8426, 0.8204, 0.8044, 0.7560, 0.7169,
#         0.7023, 0.6685, 0.6194, 0.5599, 0.5071, 0.4728, 0.4574, 0.4332, 0.2989,
#         0.2535, 0.2446, 0.2025]))
binary_roc(preds, target, thresholds=10)   # less accurate but memory constant
# (tensor([0.0000, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
#         1.0000]),
#  tensor([0.0000, 0.0000, 0.2143, 0.4286, 0.5000, 0.6429, 0.7143, 0.9286, 1.0000,
#         1.0000]),
# tensor([1.0000, 0.8889, 0.7778, 0.6667, 0.5556, 0.4444, 0.3333, 0.2222, 0.1111,
#         0.0000]))
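
Once a binned (fpr, tpr) curve like the one above is available, a binned AUROC reduces to integrating it. As an illustration only (trapezoid_auc is a hypothetical helper, not the torchmetrics implementation), a minimal pure-Python sketch using the trapezoidal rule:

```python
def trapezoid_auc(x, y):
    """Area under the piecewise-linear curve through points (x[i], y[i]),
    computed with the trapezoidal rule."""
    pts = sorted(zip(x, y))  # integrate left to right
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


# A perfect classifier's ROC rises straight up, then goes across: AUC = 1.0.
print(trapezoid_auc([0.0, 0.0, 1.0], [0.0, 1.0, 1.0]))  # 1.0
# The chance diagonal gives AUC = 0.5.
print(trapezoid_auc([0.0, 1.0], [0.0, 1.0]))  # 0.5
```

Because the binned curve only approximates the exact ROC at the chosen thresholds, the resulting AUC inherits the same accuracy/memory trade-off described above.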

Issue will be closed when #1195 is merged.
