Skip to content

Commit

Permalink
style: format code to comply with pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Sep 1, 2024
1 parent 1438a24 commit b9716b2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def compute(self) -> Tensor:
fn,
average=self.average,
multidim_average=self.multidim_average,
ignore_index = self.ignore_index,
ignore_index=self.ignore_index,
top_k=self.top_k,
zero_division=self.zero_division,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _precision_recall_reduce(
return _safe_divide(tp, tp + different_stat, zero_division)

score = _safe_divide(tp, tp + different_stat, zero_division)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index = ignore_index, top_k=top_k)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index=ignore_index, top_k=top_k)


def binary_precision(
Expand Down
13 changes: 10 additions & 3 deletions src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens


def _adjust_weights_safe_divide(
score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, ignore_index:Optional[int] = None, top_k: int = 1
score: Tensor,
average: Optional[str],
multilabel: bool,
tp: Tensor,
fp: Tensor,
fn: Tensor,
ignore_index: Optional[int] = None,
top_k: int = 1,
) -> Tensor:
if average is None or average == "none":
return score
Expand All @@ -71,10 +78,10 @@ def _adjust_weights_safe_divide(
weights = torch.ones_like(score)
if not multilabel:
weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0

if ignore_index is not None and 0 <= ignore_index < len(score):
weights[ignore_index] = 0.0

return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)


Expand Down

0 comments on commit b9716b2

Please sign in to comment.