Add ignore_idx to retrieval metrics #676

Merged on Dec 13, 2021 (22 commits)

Commits
d32a19a
implemented feature #672
lucadiliello Dec 10, 2021
e405271
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2021
67752c8
fixed typo in tests
lucadiliello Dec 10, 2021
17b1826
Merge branch 'feature-#672' of https://github.com/lucadiliello/metric…
lucadiliello Dec 10, 2021
ef5a9e3
drop none
Borda Dec 10, 2021
c6eceb4
Merge branch 'master' into feature-#672
lucadiliello Dec 10, 2021
ee2f780
Merge branch 'master' into feature-#672
SkafteNicki Dec 13, 2021
1fc6c01
changelog
SkafteNicki Dec 13, 2021
18ceb14
Update torchmetrics/retrieval/retrieval_hit_rate.py
lucadiliello Dec 13, 2021
9e03dc6
Update torchmetrics/retrieval/retrieval_metric.py
lucadiliello Dec 13, 2021
9868995
Update torchmetrics/retrieval/retrieval_ndcg.py
lucadiliello Dec 13, 2021
3cfa5a0
Update torchmetrics/retrieval/retrieval_precision.py
lucadiliello Dec 13, 2021
bfba4fe
Update torchmetrics/retrieval/retrieval_recall.py
lucadiliello Dec 13, 2021
00e777d
Update torchmetrics/retrieval/retrieval_recall.py
lucadiliello Dec 13, 2021
c05cd79
Update torchmetrics/utilities/checks.py
lucadiliello Dec 13, 2021
6663552
Update torchmetrics/retrieval/retrieval_precision.py
lucadiliello Dec 13, 2021
ec6c510
Update torchmetrics/retrieval/retrieval_ndcg.py
lucadiliello Dec 13, 2021
192ad7e
Update torchmetrics/retrieval/retrieval_hit_rate.py
lucadiliello Dec 13, 2021
c14fb55
Update torchmetrics/retrieval/retrieval_fallout.py
lucadiliello Dec 13, 2021
2c38938
Update torchmetrics/retrieval/retrieval_fallout.py
lucadiliello Dec 13, 2021
5cbac0e
Merge branch 'master' into feature-#672
Borda Dec 13, 2021
a7d0295
added Raises section to all retrieval metrics
lucadiliello Dec 13, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))


- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))


### Changed

- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))
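Note: the changelog entry above describes the behaviour added by this PR: target entries equal to `ignore_index` are dropped before the metric is computed, mirroring the convention used by loss functions. A minimal usage sketch follows; the argument name and the class-based call match the tests below, while the concrete values are illustrative and not part of the diff.

```python
# Minimal usage sketch (assumes the class-based API exercised in the tests below).
# Positions whose target equals `ignore_index` are excluded from the computation.
import torch
from torchmetrics import RetrievalMAP

indexes = torch.tensor([0, 0, 0, 1, 1, 1])             # query id for each prediction
preds = torch.tensor([0.9, 0.3, 0.5, 0.2, 0.8, 0.4])   # relevance scores
target = torch.tensor([1, 0, -100, 0, 1, -100])        # -100 marks padded / ignored positions

rmap = RetrievalMAP(empty_target_action="neg", ignore_index=-100)
print(rmap(preds, target, indexes=indexes))  # entries with target == -100 do not contribute
```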
33 changes: 31 additions & 2 deletions tests/retrieval/helpers.py
@@ -31,6 +31,7 @@
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn
from tests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt
from tests.retrieval.inputs import _input_retrieval_scores_with_ignore_index as _irs_ii
from tests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt

seed_all(42)
@@ -72,6 +73,7 @@ def _compute_sklearn_metric(
indexes: np.ndarray = None,
metric: Callable = None,
empty_target_action: str = "skip",
ignore_index: int = None,
reverse: bool = False,
**kwargs,
) -> Tensor:
@@ -90,6 +92,10 @@
assert isinstance(preds, np.ndarray)
assert isinstance(target, np.ndarray)

if ignore_index is not None:
valid_positions = target != ignore_index
indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions]

indexes = indexes.flatten()
preds = preds.flatten()
target = target.flatten()
@@ -196,6 +202,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
"`empty_target_action` received a wrong value `casual_argument`.",
dict(empty_target_action="casual_argument"),
),
# check ignore_index is valid
(
_irs.indexes,
_irs.preds,
_irs.target,
"Argument `ignore_index` must be an integer or None.",
dict(ignore_index=-100.0),
),
# check input shapes are consistent
(
_irs_mis_sz.indexes,
@@ -242,6 +256,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
"`empty_target_action` received a wrong value `casual_argument`.",
dict(empty_target_action="casual_argument"),
),
# check ignore_index is valid
(
_irs.indexes,
_irs.preds,
_irs.target,
"Argument `ignore_index` must be an integer or None.",
dict(ignore_index=-100.0),
),
# check input shapes are consistent
(
_irs_mis_sz.indexes,
@@ -292,6 +314,13 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
],
)

_default_metric_class_input_arguments_ignore_index = dict(
argnames="indexes,preds,target",
argvalues=[
(_irs_ii.indexes, _irs_ii.preds, _irs_ii.target),
],
)

_default_metric_class_input_arguments_with_non_binary_target = dict(
argnames="indexes,preds,target",
argvalues=[
@@ -444,7 +473,7 @@ def metric_functional_ignore_indexes(preds, target, indexes):
metric_module=metric_module,
metric_functional=metric_functional_ignore_indexes,
metric_args={"empty_target_action": "neg"},
indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted
indexes=indexes, # every additional argument will be passed to the retrieval metric and _sk_metric_adapted
)

def run_precision_test_gpu(
@@ -467,7 +496,7 @@ def metric_functional_ignore_indexes(preds, target, indexes):
metric_module=metric_module,
metric_functional=metric_functional_ignore_indexes,
metric_args={"empty_target_action": "neg"},
indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted
indexes=indexes, # every additional argument will be passed to retrieval metric and _sk_metric_adapted
)

@staticmethod
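The key change in the test helper is the block near the top of `_compute_sklearn_metric`: positions whose target equals `ignore_index` are removed before the data is grouped by query and handed to the sklearn reference metric. A self-contained sketch of that filtering (the function name is illustrative, not from the PR):

```python
# Standalone sketch of the filtering performed in `_compute_sklearn_metric` above:
# drop every position whose target equals `ignore_index` before grouping by query.
import numpy as np

def filter_ignore_index(indexes: np.ndarray, preds: np.ndarray, target: np.ndarray, ignore_index=None):
    """Return the inputs with positions where ``target == ignore_index`` removed."""
    if ignore_index is not None:
        valid_positions = target != ignore_index
        indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions]
    return indexes, preds, target

# Example: the third and fifth documents are padding and must not influence the metric.
idx = np.array([0, 0, 0, 1, 1])
prd = np.array([0.9, 0.3, 0.5, 0.2, 0.8])
tgt = np.array([1, 0, -100, 1, -100])
print(filter_ignore_index(idx, prd, tgt, ignore_index=-100))  # arrays of length 3
```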
8 changes: 8 additions & 0 deletions tests/retrieval/inputs.py
@@ -44,6 +44,14 @@
target=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE),
)

_input_retrieval_scores_with_ignore_index = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)).masked_fill(
mask=torch.randn(NUM_BATCHES, BATCH_SIZE) > 0.5, value=-100
),
)

# with errors
_input_retrieval_scores_no_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
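The new fixture marks a random subset of targets with `-100` via `masked_fill`. A quick illustration of that construction (values and seed are illustrative, not from the diff), using the same masking rule: positions where a standard-normal draw exceeds 0.5 (roughly 30% of entries) are overwritten with the ignore value.

```python
# Illustration of the masking used in `_input_retrieval_scores_with_ignore_index`.
import torch

torch.manual_seed(42)
target = torch.randint(high=2, size=(2, 8))          # binary relevance targets
mask = torch.randn(2, 8) > 0.5                        # ~30% of positions selected
target_with_ignore = target.masked_fill(mask=mask, value=-100)
print(target_with_ignore)                             # selected positions now hold -100
```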
34 changes: 33 additions & 1 deletion tests/retrieval/test_fallout.py
@@ -20,6 +20,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
@@ -55,6 +56,7 @@ class TestFallOut(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
@@ -65,9 +67,39 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
k: int,
):
metric_args = {"empty_target_action": empty_target_action, "k": k}
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalFallOut,
sk_metric=_fallout_at_k,
dist_sync_on_step=dist_sync_on_step,
reverse=True,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
33 changes: 32 additions & 1 deletion tests/retrieval/test_hit_rate.py
@@ -20,6 +20,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
@@ -52,6 +53,7 @@ class TestHitRate(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
@@ -62,9 +64,38 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
k: int,
):
metric_args = {"empty_target_action": empty_target_action, "k": k}
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalHitRate,
sk_metric=_hit_rate_at_k,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
31 changes: 30 additions & 1 deletion tests/retrieval/test_map.py
@@ -20,6 +20,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
@@ -35,6 +36,7 @@ class TestMAP(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
@@ -44,8 +46,35 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
):
metric_args = {"empty_target_action": empty_target_action}
metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalMAP,
sk_metric=sk_average_precision_score,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
):
metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
31 changes: 30 additions & 1 deletion tests/retrieval/test_mrr.py
@@ -21,6 +21,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
@@ -57,6 +58,7 @@ class TestMRR(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
@@ -66,8 +68,35 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
):
metric_args = {"empty_target_action": empty_target_action}
metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalMRR,
sk_metric=_reciprocal_rank,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
):
metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
33 changes: 32 additions & 1 deletion tests/retrieval/test_ndcg.py
@@ -20,6 +20,7 @@
from tests.retrieval.helpers import (
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments_ignore_index,
_default_metric_class_input_arguments_with_non_binary_target,
_default_metric_functional_input_arguments_with_non_binary_target,
_errors_test_class_metric_parameters_k,
@@ -51,6 +52,7 @@ class TestNDCG(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 3]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_class_metric(
@@ -61,9 +63,38 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
k: int,
):
metric_args = {"empty_target_action": empty_target_action, "k": k}
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalNormalizedDCG,
sk_metric=_ndcg_at_k,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,