-
Notifications
You must be signed in to change notification settings - Fork 412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BinaryPrecisionRecallCurve
for large datasets (>100 million samples)
#1309
base: master
Are you sure you want to change the base?
Changes from all commits
48188d5
1e5f130
b98edf7
ca963bb
6698d76
d5a8309
a1b687d
86b1178
bea71f6
0b3afad
560e899
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
humanfriendly |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -11,14 +11,18 @@ | |||||||||||||||||||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||
import warnings | ||||||||||||||||||||||||
from enum import Enum | ||||||||||||||||||||||||
from typing import Any, List, Optional, Tuple, Union | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import humanfriendly | ||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||
from torch import Tensor | ||||||||||||||||||||||||
from typing_extensions import Literal | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from torchmetrics.functional.classification.precision_recall_curve import ( | ||||||||||||||||||||||||
_adjust_threshold_arg, | ||||||||||||||||||||||||
_binary_clf_curve, | ||||||||||||||||||||||||
_binary_precision_recall_curve_arg_validation, | ||||||||||||||||||||||||
_binary_precision_recall_curve_compute, | ||||||||||||||||||||||||
_binary_precision_recall_curve_format, | ||||||||||||||||||||||||
|
@@ -38,6 +42,26 @@ | |||||||||||||||||||||||
from torchmetrics.metric import Metric | ||||||||||||||||||||||||
from torchmetrics.utilities.data import dim_zero_cat | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
_DYNAMIC_THRESHOLDS_MIN_NSAMPLES = 1024**2 # 1MiB,TODO: find a better way to estimate a reasonable minimum | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def _budget_bytes_to_nsamples(budget_bytes: int): | ||||||||||||||||||||||||
# assume that both preds and target ("* 2") will be of size (N, 1) and of type float32 (4 bytes) | ||||||||||||||||||||||||
return budget_bytes / (2 * 4) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def _validate_memory_budget(budget: int): | ||||||||||||||||||||||||
if budget <= 0: | ||||||||||||||||||||||||
raise ValueError("Budget must be larger than 0.") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if _budget_bytes_to_nsamples(budget) <= _DYNAMIC_THRESHOLDS_MIN_NSAMPLES: | ||||||||||||||||||||||||
warnings.warn( | ||||||||||||||||||||||||
f"Budget is relatively small ({humanfriendly.format_size(budget, binary=True)}). " | ||||||||||||||||||||||||
"The dynamic mode is recommended for bigger samples." | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return budget | ||||||||||||||||||||||||
Comment on lines
+53
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we only have the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the multiclass/label cases the estimation (number of samples) <-> (memory consumption) would be different. I don't have a very strong opinion on this, i will put some more thought on the multi* cases first. |
||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class BinaryPrecisionRecallCurve(Metric): | ||||||||||||||||||||||||
r"""Computes the precision-recall curve for binary tasks. The curve consist of multiple pairs of precision and | ||||||||||||||||||||||||
|
@@ -70,6 +94,7 @@ class BinaryPrecisionRecallCurve(Metric): | |||||||||||||||||||||||
- If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation | ||||||||||||||||||||||||
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as | ||||||||||||||||||||||||
bins for the calculation. | ||||||||||||||||||||||||
- If set to a `str`, the value is interpreted as a memory budget and the dynamic mode approach is used. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
validate_args: bool indicating if input arguments and tensors should be validated for correctness. | ||||||||||||||||||||||||
Set to ``False`` for faster computations. | ||||||||||||||||||||||||
|
@@ -101,9 +126,25 @@ class BinaryPrecisionRecallCurve(Metric): | |||||||||||||||||||||||
higher_is_better: Optional[bool] = None | ||||||||||||||||||||||||
full_state_update: bool = False | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
class _ComputationMode(Enum): | ||||||||||||||||||||||||
"""Internal state of the dynamic mode.""" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
BINNED = "binned" | ||||||||||||||||||||||||
NON_BINNED = "non-binned" | ||||||||||||||||||||||||
NON_BINNED_DYNAMIC = "non-binned-dynamic" | ||||||||||||||||||||||||
Comment on lines
+129
to
+134
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems weird to me having a class inside a class def? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's kind of rare indeed but i usually do this in such cases where it is strictly only used internally 🤷♂️ Should i pop it out? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes please |
||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def _deduce_computation_mode(thresholds: Optional[Union[int, List[float], Tensor, str]]) -> _ComputationMode: | ||||||||||||||||||||||||
if isinstance(thresholds, str): | ||||||||||||||||||||||||
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED_DYNAMIC | ||||||||||||||||||||||||
elif thresholds is None: | ||||||||||||||||||||||||
return BinaryPrecisionRecallCurve._ComputationMode.NON_BINNED | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
return BinaryPrecisionRecallCurve._ComputationMode.BINNED | ||||||||||||||||||||||||
Comment on lines
+138
to
+143
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||
self, | ||||||||||||||||||||||||
thresholds: Optional[Union[int, List[float], Tensor]] = None, | ||||||||||||||||||||||||
thresholds: Optional[Union[int, List[float], Tensor, str]] = None, | ||||||||||||||||||||||||
ignore_index: Optional[int] = None, | ||||||||||||||||||||||||
validate_args: bool = True, | ||||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||||
|
@@ -115,8 +156,23 @@ def __init__( | |||||||||||||||||||||||
self.ignore_index = ignore_index | ||||||||||||||||||||||||
self.validate_args = validate_args | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self._computation_mode = self._deduce_computation_mode(thresholds) | ||||||||||||||||||||||||
thresholds = _adjust_threshold_arg(thresholds) | ||||||||||||||||||||||||
if thresholds is None: | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if self._computation_mode == self._ComputationMode.NON_BINNED_DYNAMIC: | ||||||||||||||||||||||||
self._memory_budget_bytes = _validate_memory_budget(thresholds) | ||||||||||||||||||||||||
# used after the switch to binned mode | ||||||||||||||||||||||||
self.register_buffer("thresholds", None) | ||||||||||||||||||||||||
self.add_state( | ||||||||||||||||||||||||
"confmat", | ||||||||||||||||||||||||
default=torch.zeros(_budget_bytes_to_nsamples(self._memory_budget_bytes), 2, 2, dtype=torch.long), | ||||||||||||||||||||||||
dist_reduce_fx="sum", | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
# they are deleted after the switch to binned mode | ||||||||||||||||||||||||
self.preds = [] | ||||||||||||||||||||||||
self.target = [] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
elif thresholds is None: | ||||||||||||||||||||||||
self.thresholds = thresholds | ||||||||||||||||||||||||
self.add_state("preds", default=[], dist_reduce_fx="cat") | ||||||||||||||||||||||||
self.add_state("target", default=[], dist_reduce_fx="cat") | ||||||||||||||||||||||||
|
@@ -137,6 +193,23 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore | |||||||||||||||||||||||
self.preds.append(state[0]) | ||||||||||||||||||||||||
self.target.append(state[1]) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if self._computation_mode != self._ComputationMode.NON_BINNED_DYNAMIC: | ||||||||||||||||||||||||
return | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
all_preds = dim_zero_cat(self.preds) | ||||||||||||||||||||||||
mem_used = all_preds.element_size() * all_preds.nelement() * 2 # 2 accounts for the target | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if mem_used < self._memory_budget_bytes: | ||||||||||||||||||||||||
return | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# switch to binned mode | ||||||||||||||||||||||||
self.preds, self.target = all_preds, dim_zero_cat(self.target) | ||||||||||||||||||||||||
_, _, self.thresholds = _binary_clf_curve(self.preds, self.target) | ||||||||||||||||||||||||
# if the number of thr | ||||||||||||||||||||||||
self.confmat = _binary_precision_recall_curve_update(self.preds, self.target, self.thresholds) | ||||||||||||||||||||||||
del self.preds, self.target | ||||||||||||||||||||||||
self._computation_mode = self._ComputationMode.BINNED | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def compute(self) -> Tuple[Tensor, Tensor, Tensor]: | ||||||||||||||||||||||||
if self.thresholds is None: | ||||||||||||||||||||||||
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] | ||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ | |||||
|
||||||
from typing import List, Optional, Sequence, Tuple, Union | ||||||
|
||||||
import humanfriendly | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We really do not want to introduce any new dependencies. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I saw that one coming haha. So, i put it anyway because it feels like the kind of functionality pruned to silly mistakes while a tiny library like this has it neatly packed in. I can try to make a minimal version of it based on the library's source code. Is that a better solution? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lets make it conditional, if user already have it, then use it if module_available("humanfriendly"):
import humanfriendly
else:
humanfriendly = None |
||||||
import torch | ||||||
from torch import Tensor, tensor | ||||||
from torch.nn import functional as F | ||||||
|
@@ -80,13 +81,20 @@ def _binary_clf_curve( | |||||
|
||||||
|
||||||
def _adjust_threshold_arg( | ||||||
thresholds: Optional[Union[int, List[float], Tensor]] = None, device: Optional[torch.device] = None | ||||||
) -> Optional[Tensor]: | ||||||
"""Utility function for converting the threshold arg for list and int to tensor format.""" | ||||||
thresholds: Optional[Union[int, List[float], Tensor, str]] = None, device: Optional[torch.device] = None | ||||||
) -> Optional[Union[Tensor, int]]: | ||||||
"""Utility function for converting the threshold arg. | ||||||
|
||||||
- list and int -> tensor | ||||||
- None -> None | ||||||
- str -> int (memory budget) in Mb | ||||||
""" | ||||||
if isinstance(thresholds, int): | ||||||
thresholds = torch.linspace(0, 1, thresholds, device=device) | ||||||
if isinstance(thresholds, list): | ||||||
thresholds = torch.tensor(thresholds, device=device) | ||||||
if isinstance(thresholds, str): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
thresholds = humanfriendly.parse_size(thresholds, binary=True) | ||||||
return thresholds | ||||||
|
||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets have this as optional