Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 18, 2022
1 parent b6f5b5b commit f6997fe
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
25 changes: 11 additions & 14 deletions tests/retrieval/test_recall_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,28 @@
from typing import Tuple, Union

import numpy as np
from numpy import array
import pytest
import torch
from torch import tensor
from torch import Tensor
from numpy import array
from torch import Tensor, tensor

from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from tests.retrieval.test_recall import _recall_at_k
from tests.retrieval.test_precision import _precision_at_k
from tests.retrieval.helpers import get_group_indexes
from tests.retrieval.test_precision import _precision_at_k
from tests.retrieval.test_recall import _recall_at_k
from torchmetrics import RetrievalRecallAtFixedPrecision

seed_all(42)


def _compute_recall_at_precision_metric(
preds: Union[Tensor, array],
target: Union[Tensor, array],
indexes: Union[Tensor, array] = None,
max_k: int = None,
min_precision: float = 0.0,
ignore_index: int = None,
preds: Union[Tensor, array],
target: Union[Tensor, array],
indexes: Union[Tensor, array] = None,
max_k: int = None,
min_precision: float = 0.0,
ignore_index: int = None,
) -> Tuple[Tensor, int]:
"""Compute metric with multiple iterations over every query predictions set."""
recalls, precisions = [], []
Expand Down Expand Up @@ -85,9 +84,7 @@ def _compute_recall_at_precision_metric(
recalls = tensor(recalls).mean(dim=0)
precisions = tensor(precisions).mean(dim=0)

recalls_at_k = [
(r, k) for p, r, k in zip(precisions, recalls, max_k_range) if p >= min_precision
]
recalls_at_k = [(r, k) for p, r, k in zip(precisions, recalls, max_k_range) if p >= min_precision]

assert recalls_at_k

Expand Down
5 changes: 1 addition & 4 deletions torchmetrics/retrieval/recall_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,7 @@ def compute(self) -> Tuple[Tensor, int]:
recalls_at_k = [(r, k) for p, r, k in prk if p >= self.min_precision]

if not recalls_at_k:
raise MinPrecisionError(
f'Not found recalls to precision: {self.min_precision}. '
f'Try lower values.'
)
raise MinPrecisionError(f"Not found recalls to precision: {self.min_precision}. " f"Try lower values.")

# return best pair recall, k
return max(recalls_at_k)

0 comments on commit f6997fe

Please sign in to comment.