From 3df4e1b43347415af209c91df6aad3cd397519dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bj=C3=B6rn=20Barz?=
Date: Thu, 23 Feb 2023 10:31:32 +0100
Subject: [PATCH] Improve speed and memory consumption of binned `PrecisionRecallCurve` (#1493)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Björn Barz
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka
---
 CHANGELOG.md                                      |  4 +
 .../classification/precision_recall_curve.py      | 85 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index af15c429c3b..b26bcc71c74 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -43,8 +43,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Extend `EnumStr` raising `ValueError` for invalid value ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))
 
+- Improve speed and memory consumption of binned `PrecisionRecallCurve` with a large number of samples ([#1493](https://github.com/Lightning-AI/metrics/pull/1493))
+
+
 - Changed `__iter__` method from raising `NotImplementedError` to `TypeError` by setting to `None` ([#1538](https://github.com/Lightning-AI/metrics/pull/1538))
 
+
 ### Deprecated
 
 -
diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 79737a77a07..374cc3d6d9f 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -194,6 +194,23 @@ def _binary_precision_recall_curve_update(
     """
     if thresholds is None:
         return preds, target
+    if preds.numel() <= 50_000:
+        update_fn = _binary_precision_recall_curve_update_vectorized
+    else:
+        update_fn = _binary_precision_recall_curve_update_loop
+    return update_fn(preds, target, thresholds)
+
+
+def _binary_precision_recall_curve_update_vectorized(
+    preds: Tensor,
+    target: Tensor,
+    thresholds: Tensor,
+) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
+
+    This implementation is vectorized and faster than `_binary_precision_recall_curve_update_loop` for small
+    numbers of samples (up to 50k) but less memory- and time-efficient for more samples.
+    """
     len_t = len(thresholds)
     preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()  # num_samples x num_thresholds
     unique_mapping = preds_t + 2 * target.unsqueeze(-1) + 4 * torch.arange(len_t, device=target.device)
@@ -201,6 +218,30 @@
     return bins.reshape(len_t, 2, 2)
 
 
+def _binary_precision_recall_curve_update_loop(
+    preds: Tensor,
+    target: Tensor,
+    thresholds: Tensor,
+) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
+
+    This implementation loops over thresholds and is more memory-efficient than
+    `_binary_precision_recall_curve_update_vectorized`. However, it is slower for small
+    numbers of samples (up to 50k).
+    """
+    len_t = len(thresholds)
+    target = target == 1
+    confmat = thresholds.new_empty((len_t, 2, 2), dtype=torch.int64)
+    # Iterate one threshold at a time to conserve memory
+    for i in range(len_t):
+        preds_t = preds >= thresholds[i]
+        confmat[i, 1, 1] = (target & preds_t).sum()
+        confmat[i, 0, 1] = ((~target) & preds_t).sum()
+        confmat[i, 1, 0] = (target & (~preds_t)).sum()
+    confmat[:, 0, 0] = len(preds_t) - confmat[:, 0, 1] - confmat[:, 1, 0] - confmat[:, 1, 1]
+    return confmat
+
+
 def _binary_precision_recall_curve_compute(
     state: Union[Tensor, Tuple[Tensor, Tensor]],
     thresholds: Optional[Tensor],
@@ -409,8 +450,25 @@
     """
     if thresholds is None:
         return preds, target
+    if preds.numel() * num_classes <= 1_000_000:
+        update_fn = _multiclass_precision_recall_curve_update_vectorized
+    else:
+        update_fn = _multiclass_precision_recall_curve_update_loop
+    return update_fn(preds, target, num_classes, thresholds)
+
+
+def _multiclass_precision_recall_curve_update_vectorized(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    thresholds: Tensor,
+) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
+
+    This implementation is vectorized and faster than `_multiclass_precision_recall_curve_update_loop` for small
+    numbers of samples but less memory- and time-efficient for more samples.
+    """
     len_t = len(thresholds)
-    # num_samples x num_classes x num_thresholds
     preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
     target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
     unique_mapping = preds_t + 2 * target_t.unsqueeze(-1)
@@ -420,6 +478,31 @@
     return bins.reshape(len_t, num_classes, 2, 2)
 
 
+def _multiclass_precision_recall_curve_update_loop(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    thresholds: Tensor,
+) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    """Returns the multi-threshold confusion matrix to calculate the pr-curve with.
+
+    This implementation loops over thresholds and is more memory-efficient than
+    `_multiclass_precision_recall_curve_update_vectorized`. However, it is slower for small
+    numbers of samples.
+    """
+    len_t = len(thresholds)
+    target_t = torch.nn.functional.one_hot(target, num_classes=num_classes)
+    confmat = thresholds.new_empty((len_t, num_classes, 2, 2), dtype=torch.int64)
+    # Iterate one threshold at a time to conserve memory
+    for i in range(len_t):
+        preds_t = preds >= thresholds[i]
+        confmat[i, :, 1, 1] = (target_t & preds_t).sum(dim=0)
+        confmat[i, :, 0, 1] = ((~target_t) & preds_t).sum(dim=0)
+        confmat[i, :, 1, 0] = (target_t & (~preds_t)).sum(dim=0)
+    confmat[:, :, 0, 0] = len(preds_t) - confmat[:, :, 0, 1] - confmat[:, :, 1, 0] - confmat[:, :, 1, 1]
+    return confmat
+
+
 def _multiclass_precision_recall_curve_compute(
     state: Union[Tensor, Tuple[Tensor, Tensor]],
     num_classes: int,