Skip to content

Commit

Permalink
WIP ~ added draft of test
Browse files Browse the repository at this point in the history
  • Loading branch information
enuk1dze committed Apr 18, 2022
1 parent f6997fe commit e0ea387
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions tests/retrieval/test_recall_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@
# limitations under the License.


from functools import partial
from typing import Tuple, Union

import numpy as np
import pytest
import torch
from numpy import array
from torch import Tensor, tensor

from tests import MetricTester
from tests.helpers import seed_all
from tests.helpers.testers import MetricTester
from tests.retrieval.helpers import _irs
from tests.retrieval.helpers import get_group_indexes
from tests.retrieval.test_precision import _precision_at_k
from tests.retrieval.test_recall import _recall_at_k
Expand Down Expand Up @@ -106,3 +105,39 @@ def test_compute_recall_at_precision_metric():
min_precision,
)
assert res == (tensor(0.5000), 1)


@pytest.mark.parametrize(
"indexes,preds,target",
[
(i, p,t) for i, p, t in zip(_irs.indexes, _irs.preds, _irs.target)
]
)
@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("max_k", [None, 1, 4, 10])
@pytest.mark.parametrize("min_precision", [.0, .2])
class TestRetrievalRecallAtFixedPrecision(MetricTester):
atol = 0.02

def test_12312312(
self, indexes, preds, target, ddp, dist_sync_on_step, empty_target_action, ignore_index, max_k, min_precision
):

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalRecallAtFixedPrecision,
sk_metric=_compute_recall_at_precision_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args={
"max_k": max_k,
"min_precision": min_precision,
"ignore_index": ignore_index,
"empty_target_action": empty_target_action,
},
)

0 comments on commit e0ea387

Please sign in to comment.