diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 3e324a7f353..eff2bd3998c 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -103,11 +103,6 @@ class BinnedPrecisionRecallCurve(Metric): tensor([0.0000, 0.5000, 1.0000])] """ - TPs: Tensor - FPs: Tensor - FNs: Tensor - thresholds: Tensor - def __init__( self, num_classes: int, @@ -165,7 +160,7 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: precisions = torch.cat([ precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device) ], - dim=1) + dim=1) recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)], dim=1)