BinaryPrecisionRecallCurve for large datasets (>100 million samples) #1309
Conversation
Hi @jpcbertoldo, thanks for proposing this enhancement (I have also seen your thread on slack). What if

```python
metric = BinaryPrecisionRecallCurve(thresholds="12345")
metric = BinaryPrecisionRecallCurve(thresholds="250mb")
metric = BinaryPrecisionRecallCurve(thresholds="3gb")
```

all enabled this feature? What do you think about that?
Yeah, makes sense. I was actually thinking of considering just the case …
@SkafteNicki by the way, I just saw this. Could we skip this step this time? :)
(force-pushed jpcbertoldo/roc-for-large-datasets from b98edf7 to d5a8309)
That is normally how we like it. But since you had already implemented some, there is no need to open an issue about this. It is recommended because we do not want people to implement a lot before talking with us; we may decide it is not a feature we want, and then their work would go to waste.
Have you given any thought to extending this to multiclass / multilabel?
```diff
@@ -14,6 +14,7 @@
 from typing import List, Optional, Sequence, Tuple, Union

+import humanfriendly
```
We really do not want to introduce any new dependencies.
The conversion from mb and gb seems to be something we can do ourselves?
I saw that one coming haha.
So, I put it in anyway because it feels like the kind of functionality that is prone to silly mistakes, while a tiny library like this has it neatly packed in.
I can try to make a minimal version of it based on the library's source code. Is that a better solution?
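For example, something roughly like this (just a sketch of the idea; the helper name `_parse_size` and the accepted suffixes are illustrative, not code from this PR):

```python
# Hypothetical minimal, dependency-free replacement for humanfriendly.parse_size.
_SIZE_SUFFIXES = {"b": 1, "kb": 1024, "mb": 1024**2, "gb": 1024**3}


def _parse_size(size: str) -> int:
    """Parse strings like '250mb' or '3gb' into a number of bytes (binary multiples)."""
    size = size.strip().lower()
    for suffix, factor in _SIZE_SUFFIXES.items():
        number = size[: -len(suffix)].strip()
        if size.endswith(suffix) and number.isdigit():
            return int(number) * factor
    raise ValueError(f"Could not parse memory budget: {size!r}")
```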
Let's make it conditional: if the user already has it, then use it.

```python
if module_available("humanfriendly"):
    import humanfriendly
else:
    humanfriendly = None
```
```python
class _ComputationMode(Enum):
    """Internal state of the dynamic mode."""

    BINNED = "binned"
    NON_BINNED = "non-binned"
    NON_BINNED_DYNAMIC = "non-binned-dynamic"
```
It seems weird to me to have a class inside a class definition?
It's kind of rare indeed, but I usually do this in cases where it is strictly only used internally 🤷♂️
Should I pop it out?
yes please
```python
def _validate_memory_budget(budget: int):
    if budget <= 0:
        raise ValueError("Budget must be larger than 0.")

    if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES:
        warnings.warn(
            f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). "
            "The dynamic mode is recommended for bigger samples."
        )

    return budget
```
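(`_budget_bytes_to_nsamples` is not shown in this excerpt; presumably it is something like the sketch below, where the bytes-per-sample constant is an assumption, not the PR's actual code.)

```python
# Assumed: one float32 pred plus one float32 target stored per sample.
_BYTES_PER_SAMPLE = 2 * 4


def _budget_bytes_to_nsamples(budget: int) -> int:
    """Convert a memory budget in bytes into the number of (pred, target) pairs it can hold."""
    return budget // _BYTES_PER_SAMPLE
```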
If we only have the `thresholds` argument, I guess all this logic can be moved to the `_adjust_threshold_arg` function.
In the multiclass/multilabel cases the estimated mapping (number of samples) <-> (memory consumption) would be different.
I don't have a very strong opinion on this; I will put some more thought into the multi* cases first.
Just as much as in my reply to your comment 😬. I'll be putting some effort into that this weekend :)
@SkafteNicki I'm having a hard time making myself available to invest time in this.
@stancld could you help finish this PR here? 🦦
```diff
 from typing import Any, List, Optional, Tuple, Union

+import humanfriendly
```
Let's have this as optional.
```python
if isinstance(thresholds, str):
    return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
elif thresholds is None:
    return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
else:
    return BinaryPrecisionRecallCurve._ComputationMode.BINNED
```
Suggested change:

```diff
 if isinstance(thresholds, str):
     return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC
-elif thresholds is None:
+if thresholds is None:
     return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED
-else:
-    return BinaryPrecisionRecallCurve._ComputationMode.BINNED
+return BinaryPrecisionRecallCurve._ComputationMode.BINNED
```
```python
if isinstance(thresholds, int):
    thresholds = torch.linspace(0, 1, thresholds, device=device)
if isinstance(thresholds, list):
    thresholds = torch.tensor(thresholds, device=device)
if isinstance(thresholds, str):
```
Suggested change:

```diff
-if isinstance(thresholds, str):
+if isinstance(thresholds, str) and humanfriendly:
```
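For reference, `humanfriendly` already ships a parser for exactly this conversion; a small usage illustration (not code from the PR):

```python
import humanfriendly

# parse_size converts a human-readable size into a byte count;
# binary=True interprets "mb"/"gb" as powers of 1024.
print(humanfriendly.parse_size("250mb", binary=True))  # 250 * 1024**2
print(humanfriendly.parse_size("3gb", binary=True))    # 3 * 1024**3
```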
@jpcbertoldo, how is it going here? I think we are on a good path... :)
Hi @Borda, I lost track of this a while ago (vacation, other priorities...). I will try to come back to it soon!
@Borda I was wondering if we could do a simpler solution for this by computing the metric in 2 rounds ("epochs"):

1st round: just find the min/max, then, with a given maximum number of points, space the thresholds linearly -- or in some smarter way? -- between the min/max.
2nd round: compute the actual TPR/FPRs in an online way (because now the thresholds are known in advance).

An alternative for the 1st round: instead of just min/max, keep track of the unique values, which may provide information for eventually making parts of the threshold range more or less dense. The set of unique values could be in …
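For concreteness, a rough sketch of that 2-round idea (the function name, the linear spacing, and the batch-iteration interface are all illustrative assumptions, not an API proposal):

```python
import torch


def two_round_pr_curve(batches, num_thresholds: int = 1000):
    """Illustrative 2-round scheme; `batches` is a callable returning an
    iterable of (preds, target) pairs, so the data can be streamed twice."""
    # Round 1: stream over the data once, only tracking min/max of the scores.
    lo, hi = float("inf"), float("-inf")
    for preds, _ in batches():
        lo = min(lo, preds.min().item())
        hi = max(hi, preds.max().item())

    # Thresholds are now known in advance (linearly spaced here; could be smarter).
    thresholds = torch.linspace(lo, hi, num_thresholds)

    # Round 2: online confusion counts -- memory is constant in the sample size.
    tp = torch.zeros(num_thresholds)
    fp = torch.zeros(num_thresholds)
    fn = torch.zeros(num_thresholds)
    for preds, target in batches():
        is_pos = target.bool().unsqueeze(-1)          # (batch, 1)
        pred_pos = preds.unsqueeze(-1) >= thresholds  # (batch, num_thresholds)
        tp += (pred_pos & is_pos).sum(0)
        fp += (pred_pos & ~is_pos).sum(0)
        fn += (~pred_pos & is_pos).sum(0)

    precision = tp / (tp + fp).clamp(min=1)  # clamp only guards against 0/0
    recall = tp / (tp + fn).clamp(min=1)
    return precision, recall, thresholds
```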
@jpcbertoldo apologies for the late reply...
(force-pushed from d0a5568 to 9fc79ae)
What does this PR do?
This PR provides an alternative to this warning:
Context

Currently, `BinaryPrecisionRecallCurve` (and related metrics by consequence) has two operation modes, which I will call "computed-thresholds" (arg `thresholds=None`) and "given-thresholds" (otherwise).

"computed-thresholds"

All possible thresholds in the sample are used in `compute()`. It has a high memory consumption if the number of instances is high, because all the preds and targets are kept in memory during the updates.

How much is "high" memory consumption? I will consider 100 million samples the order of magnitude where "high" starts. This roughly corresponds to ~750 MB for the tensors `preds` and `target` together under the "computed-thresholds" mode: `1e8 (samples) * 4 (bytes/float32) ≈ 375 MB` each.
."given-thresholds"
Thresholds are pre-defined, so all possible binarizations are known in advance. At each update all are tested, giving as many confusion matrices.
The memory consumption is low because and a function of the number of thresholds
If
num_thresholds = 10e6
, then it uses roughly10e6 * (2 *2) (confmat shape) * 8 (bytes/long) ~= 300mb
which is much lower even with 1 million thresholds.However, it requires the user to know meaningful threshols in advance, which is often not the case.
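Spelling the two estimates above out as explicit arithmetic (same numbers as in the text):

```python
# "computed-thresholds": state grows with the number of samples.
n_samples = 10**8                           # 100 million samples
preds_bytes = n_samples * 4                 # float32 preds -> ~375 MB (x2 with target)

# "given-thresholds": state grows with the number of thresholds.
n_thresholds = int(10e6)                    # 10 million thresholds
confmat_bytes = n_thresholds * (2 * 2) * 8  # one (2, 2) int64 confmat each -> ~300 MB
```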
A possible (but inconvenient) solution would be to estimate the thresholds on a first call, then compute the curve.
Solution

I propose a hybrid strategy where the update switches from "computed-thresholds" to "given-thresholds" dynamically. Given a budget (a number of instances), the metric keeps preds and targets in its state until that budget is reached, then estimates meaningful thresholds and computes the confusion matrices at that point. From there on it behaves like the "given-thresholds" mode.

API sketch
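Reconstructed from the examples discussed earlier in this thread (the exact set of accepted strings was still under discussion), the intended usage would look roughly like:

```python
from torchmetrics.classification import BinaryPrecisionRecallCurve

# Existing behavior is unchanged:
metric = BinaryPrecisionRecallCurve(thresholds=None)     # "computed-thresholds"
metric = BinaryPrecisionRecallCurve(thresholds=100)      # "given-thresholds"

# Proposed: a string sets a budget; raw (preds, target) are kept in the state
# until the budget is reached, then the metric switches to binned updates.
metric = BinaryPrecisionRecallCurve(thresholds="12345")  # budget as a number of samples
metric = BinaryPrecisionRecallCurve(thresholds="250mb")  # budget as a memory size
metric = BinaryPrecisionRecallCurve(thresholds="3gb")
```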
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.