diff --git a/tests/retrieval/test_recall_precision.py b/tests/retrieval/test_recall_precision.py
index 5b3a1a8164c..ad405759e8e 100644
--- a/tests/retrieval/test_recall_precision.py
+++ b/tests/retrieval/test_recall_precision.py
@@ -13,17 +13,16 @@
 # limitations under the License.
-from functools import partial
 from typing import Tuple, Union
 
 import numpy as np
 import pytest
-import torch
 from numpy import array
 from torch import Tensor, tensor
 
+from tests import MetricTester
 from tests.helpers import seed_all
-from tests.helpers.testers import MetricTester
+from tests.retrieval.helpers import _irs
 from tests.retrieval.helpers import get_group_indexes
 from tests.retrieval.test_precision import _precision_at_k
 from tests.retrieval.test_recall import _recall_at_k
@@ -106,3 +105,39 @@ def test_compute_recall_at_precision_metric():
         min_precision,
     )
     assert res == (tensor(0.5000), 1)
+
+
+@pytest.mark.parametrize(
+    "indexes,preds,target",
+    [
+        (i, p, t) for i, p, t in zip(_irs.indexes, _irs.preds, _irs.target)
+    ]
+)
+@pytest.mark.parametrize("ddp", [False])
+@pytest.mark.parametrize("dist_sync_on_step", [False])
+@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
+@pytest.mark.parametrize("ignore_index", [None, 1])  # avoid 0, otherwise the case with all-zero targets will fail
+@pytest.mark.parametrize("max_k", [None, 1, 4, 10])
+@pytest.mark.parametrize("min_precision", [0.0, 0.2])
+class TestRetrievalRecallAtFixedPrecision(MetricTester):
+    atol = 0.02  # absolute tolerance when comparing against the reference implementation
+
+    def test_class_metric(
+        self, indexes, preds, target, ddp, dist_sync_on_step, empty_target_action, ignore_index, max_k, min_precision
+    ):
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            indexes=indexes,
+            preds=preds,
+            target=target,
+            metric_class=RetrievalRecallAtFixedPrecision,
+            sk_metric=_compute_recall_at_precision_metric,
+            dist_sync_on_step=dist_sync_on_step,
+            metric_args={
+                "max_k": max_k,
+                "min_precision": min_precision,
+                "ignore_index": ignore_index,
+                "empty_target_action": empty_target_action,
+            },
+        )
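
For reference, a minimal sketch (not part of the diff) of how the metric under test is driven, assuming `RetrievalRecallAtFixedPrecision` is importable from `torchmetrics`, accepts the same keyword arguments the test passes through `metric_args`, and that `compute()` returns a `(recall, k)` pair as asserted in `test_compute_recall_at_precision_metric` above; the input tensors are illustrative only:

```python
import torch
from torchmetrics import RetrievalRecallAtFixedPrecision

# Two queries, distinguished by `indexes`; values are made up for illustration.
indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])

# Same arguments the parametrized test sweeps via `metric_args`.
metric = RetrievalRecallAtFixedPrecision(min_precision=0.2, max_k=4, empty_target_action="skip")
metric.update(preds, target, indexes=indexes)
recall, best_k = metric.compute()  # best recall with precision >= min_precision, and the k achieving it
```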