diff --git a/CHANGELOG.md b/CHANGELOG.md index 13f09f6dcf8..3ff2faff315 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556)) +- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676)) + + ### Changed - Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622)) diff --git a/tests/retrieval/helpers.py b/tests/retrieval/helpers.py index bbff0e85c74..1cbf3078096 100644 --- a/tests/retrieval/helpers.py +++ b/tests/retrieval/helpers.py @@ -31,6 +31,7 @@ from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn from tests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt +from tests.retrieval.inputs import _input_retrieval_scores_with_ignore_index as _irs_ii from tests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt seed_all(42) @@ -72,6 +73,7 @@ def _compute_sklearn_metric( indexes: np.ndarray = None, metric: Callable = None, empty_target_action: str = "skip", + ignore_index: int = None, reverse: bool = False, **kwargs, ) -> Tensor: @@ -90,6 +92,10 @@ def _compute_sklearn_metric( assert isinstance(preds, np.ndarray) assert isinstance(target, np.ndarray) + if ignore_index is not None: + valid_positions = target != ignore_index + indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions] + indexes = indexes.flatten() preds = preds.flatten() target = target.flatten() @@ -196,6 +202,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict: "`empty_target_action` received a wrong value `casual_argument`.", dict(empty_target_action="casual_argument"), ), + # check ignore_index is valid + ( + _irs.indexes, + _irs.preds, + _irs.target, + "Argument `ignore_index` must be an integer or None.", + dict(ignore_index=-100.0), + ), # check input shapes are consistent ( _irs_mis_sz.indexes, @@ -242,6 +256,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict: "`empty_target_action` received a wrong value `casual_argument`.", dict(empty_target_action="casual_argument"), ), + # check ignore_index is valid + ( + _irs.indexes, + _irs.preds, + _irs.target, + "Argument `ignore_index` must be an integer or None.", + dict(ignore_index=-100.0), + ), # check input shapes are consistent ( _irs_mis_sz.indexes, @@ -292,6 +314,13 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict: ], ) +_default_metric_class_input_arguments_ignore_index = dict( + argnames="indexes,preds,target", + argvalues=[ + (_irs_ii.indexes, _irs_ii.preds, _irs_ii.target), + ], +) + _default_metric_class_input_arguments_with_non_binary_target = dict( argnames="indexes,preds,target", argvalues=[ @@ -444,7 +473,7 @@ def metric_functional_ignore_indexes(preds, target, indexes): metric_module=metric_module, metric_functional=metric_functional_ignore_indexes, metric_args={"empty_target_action": "neg"}, - indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted + indexes=indexes, # every additional argument will be passed to the retrieval metric and _sk_metric_adapted ) def run_precision_test_gpu( @@ -467,7 +496,7 @@ def metric_functional_ignore_indexes(preds, target, indexes):
metric_module=metric_module, metric_functional=metric_functional_ignore_indexes, metric_args={"empty_target_action": "neg"}, - indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted + indexes=indexes, # every additional argument will be passed to retrieval metric and _sk_metric_adapted ) @staticmethod diff --git a/tests/retrieval/inputs.py b/tests/retrieval/inputs.py index 2c688b11aa7..95bf476afd1 100644 --- a/tests/retrieval/inputs.py +++ b/tests/retrieval/inputs.py @@ -44,6 +44,14 @@ target=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE), ) +_input_retrieval_scores_with_ignore_index = Input( + indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)).masked_fill( + mask=torch.randn(NUM_BATCHES, BATCH_SIZE) > 0.5, value=-100 + ), +) + # with errors _input_retrieval_scores_no_target = Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), diff --git a/tests/retrieval/test_fallout.py b/tests/retrieval/test_fallout.py index 7f87c678caf..61cd94960d5 100644 --- a/tests/retrieval/test_fallout.py +++ b/tests/retrieval/test_fallout.py @@ -20,6 +20,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, @@ -55,6 +56,7 @@ class TestFallOut(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize("k", [None, 1, 4, 10]) @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( @@ -65,9 +67,39 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, k: int, ): - metric_args = {"empty_target_action": empty_target_action, "k": k} + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalFallOut, + sk_metric=_fallout_at_k, + dist_sync_on_step=dist_sync_on_step, + reverse=True, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_hit_rate.py b/tests/retrieval/test_hit_rate.py index d68badb9af5..06da9be6587 100644 --- a/tests/retrieval/test_hit_rate.py +++ b/tests/retrieval/test_hit_rate.py @@ -20,6 +20,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, 
_errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, @@ -52,6 +53,7 @@ class TestHitRate(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize("k", [None, 1, 4, 10]) @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( @@ -62,9 +64,38 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, k: int, ): - metric_args = {"empty_target_action": empty_target_action, "k": k} + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalHitRate, + sk_metric=_hit_rate_at_k, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_map.py b/tests/retrieval/test_map.py index 2f92449b046..8a7e3a67a75 100644 --- a/tests/retrieval/test_map.py +++ b/tests/retrieval/test_map.py @@ -20,6 +20,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, @@ -35,6 +36,7 @@ class TestMAP(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, @@ -44,8 +46,35 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, ): - metric_args = {"empty_target_action": empty_target_action} + metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalMAP, + sk_metric=sk_average_precision_score, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + 
dist_sync_on_step: bool, + empty_target_action: str, + ): + metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_mrr.py b/tests/retrieval/test_mrr.py index ef6fcae986d..9e3cc318876 100644 --- a/tests/retrieval/test_mrr.py +++ b/tests/retrieval/test_mrr.py @@ -21,6 +21,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, @@ -57,6 +58,7 @@ class TestMRR(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, @@ -66,8 +68,35 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, ): - metric_args = {"empty_target_action": empty_target_action} + metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalMRR, + sk_metric=_reciprocal_rank, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + ): + metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_ndcg.py b/tests/retrieval/test_ndcg.py index ad606b9a60b..ff6b5a0737a 100644 --- a/tests/retrieval/test_ndcg.py +++ b/tests/retrieval/test_ndcg.py @@ -20,6 +20,7 @@ from tests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, + _default_metric_class_input_arguments_ignore_index, _default_metric_class_input_arguments_with_non_binary_target, _default_metric_functional_input_arguments_with_non_binary_target, _errors_test_class_metric_parameters_k, @@ -51,6 +52,7 @@ class TestNDCG(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 3]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize("k", [None, 1, 4, 10]) @pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target) def test_class_metric( @@ -61,9 +63,38 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, k: int, ): - metric_args = {"empty_target_action": empty_target_action, "k": k} + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + 
target=target, + metric_class=RetrievalNormalizedDCG, + sk_metric=_ndcg_at_k, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_precision.py b/tests/retrieval/test_precision.py index 0c8a5cc3ee2..e8541a60cc4 100644 --- a/tests/retrieval/test_precision.py +++ b/tests/retrieval/test_precision.py @@ -20,6 +20,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, @@ -56,6 +57,7 @@ class TestPrecision(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize("k", [None, 1, 4, 10]) @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( @@ -66,9 +68,38 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, k: int, ): - metric_args = {"empty_target_action": empty_target_action, "k": k} + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalPrecision, + sk_metric=_precision_at_k, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_r_precision.py b/tests/retrieval/test_r_precision.py index 7822991b9a5..5d2c103c916 100644 --- a/tests/retrieval/test_r_precision.py +++ b/tests/retrieval/test_r_precision.py @@ -20,6 +20,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_no_pos_target, @@ -51,6 +52,7 @@ class TestRPrecision(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) 
@pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( self, @@ -60,8 +62,35 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, ): - metric_args = {"empty_target_action": empty_target_action} + metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRPrecision, + sk_metric=_r_precision, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + ): + metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git a/tests/retrieval/test_recall.py b/tests/retrieval/test_recall.py index 126804ab041..cbb5b8fd40d 100644 --- a/tests/retrieval/test_recall.py +++ b/tests/retrieval/test_recall.py @@ -20,6 +20,7 @@ RetrievalMetricTester, _concat_tests, _default_metric_class_input_arguments, + _default_metric_class_input_arguments_ignore_index, _default_metric_functional_input_arguments, _errors_test_class_metric_parameters_default, _errors_test_class_metric_parameters_k, @@ -55,6 +56,7 @@ class TestRecall(RetrievalMetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail @pytest.mark.parametrize("k", [None, 1, 4, 10]) @pytest.mark.parametrize(**_default_metric_class_input_arguments) def test_class_metric( @@ -65,9 +67,38 @@ def test_class_metric( target: Tensor, dist_sync_on_step: bool, empty_target_action: str, + ignore_index: int, k: int, ): - metric_args = {"empty_target_action": empty_target_action, "k": k} + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index) + + self.run_class_metric_test( + ddp=ddp, + indexes=indexes, + preds=preds, + target=target, + metric_class=RetrievalRecall, + sk_metric=_recall_at_k, + dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args, + ) + + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + @pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"]) + @pytest.mark.parametrize("k", [None, 1, 4, 10]) + @pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index) + def test_class_metric_ignore_index( + self, + ddp: bool, + indexes: Tensor, + preds: Tensor, + target: Tensor, + dist_sync_on_step: bool, + empty_target_action: str, + k: int, + ): + metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100) self.run_class_metric_test( ddp=ddp, diff --git 
a/torchmetrics/retrieval/mean_average_precision.py b/torchmetrics/retrieval/mean_average_precision.py index 996b43d7bc3..923e361784c 100644 --- a/torchmetrics/retrieval/mean_average_precision.py +++ b/torchmetrics/retrieval/mean_average_precision.py @@ -42,6 +42,8 @@ class RetrievalMAP(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: @@ -53,6 +55,12 @@ class RetrievalMAP(RetrievalMetric): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + Raises: + ValueError: + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + Example: >>> from torchmetrics import RetrievalMAP >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) diff --git a/torchmetrics/retrieval/mean_reciprocal_rank.py b/torchmetrics/retrieval/mean_reciprocal_rank.py index 07cc0601971..e50377dda22 100644 --- a/torchmetrics/retrieval/mean_reciprocal_rank.py +++ b/torchmetrics/retrieval/mean_reciprocal_rank.py @@ -42,6 +42,8 @@ class RetrievalMRR(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: @@ -53,6 +55,12 @@ class RetrievalMRR(RetrievalMetric): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + Raises: + ValueError: + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + Example: >>> from torchmetrics import RetrievalMRR >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) diff --git a/torchmetrics/retrieval/retrieval_fallout.py b/torchmetrics/retrieval/retrieval_fallout.py index 03f68f27412..5042d197cb3 100644 --- a/torchmetrics/retrieval/retrieval_fallout.py +++ b/torchmetrics/retrieval/retrieval_fallout.py @@ -46,6 +46,8 @@ class RetrievalFallOut(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. k: consider only the top k elements for each query (default: `None`, which considers them all) compute_on_step: Forward only calls ``update()`` and return None if this is set to False. @@ -60,7 +62,11 @@ class RetrievalFallOut(RetrievalMetric): Raises: ValueError: - If ``k`` parameter is not `None` or an integer larger than 0 + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0. 
Example: >>> from torchmetrics import RetrievalFallOut @@ -77,7 +83,8 @@ class RetrievalFallOut(RetrievalMetric): def __init__( self, empty_target_action: str = "pos", - k: int = None, + ignore_index: Optional[int] = None, + k: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -85,6 +92,7 @@ def __init__( ) -> None: super().__init__( empty_target_action=empty_target_action, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, diff --git a/torchmetrics/retrieval/retrieval_hit_rate.py b/torchmetrics/retrieval/retrieval_hit_rate.py index 566edb29d06..34ad142f393 100644 --- a/torchmetrics/retrieval/retrieval_hit_rate.py +++ b/torchmetrics/retrieval/retrieval_hit_rate.py @@ -44,6 +44,8 @@ class RetrievalHitRate(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. k: consider only the top k elements for each query (default: `None`, which considers them all) compute_on_step: Forward only calls ``update()`` and return None if this is set to False. @@ -58,7 +60,11 @@ class RetrievalHitRate(RetrievalMetric): Raises: ValueError: - If ``k`` parameter is not `None` or an integer larger than 0 + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0. Example: >>> from torchmetrics import RetrievalHitRate @@ -75,7 +81,8 @@ class RetrievalHitRate(RetrievalMetric): def __init__( self, empty_target_action: str = "neg", - k: int = None, + ignore_index: Optional[int] = None, + k: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -83,6 +90,7 @@ def __init__( ) -> None: super().__init__( empty_target_action=empty_target_action, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, diff --git a/torchmetrics/retrieval/retrieval_metric.py b/torchmetrics/retrieval/retrieval_metric.py index 1ad81d3ab54..b4a3759bea6 100644 --- a/torchmetrics/retrieval/retrieval_metric.py +++ b/torchmetrics/retrieval/retrieval_metric.py @@ -51,6 +51,8 @@ class RetrievalMetric(Metric, ABC): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: @@ -61,6 +63,12 @@ class RetrievalMetric(Metric, ABC): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + + Raises: + ValueError: + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. 
""" indexes: List[Tensor] @@ -71,6 +79,7 @@ class RetrievalMetric(Metric, ABC): def __init__( self, empty_target_action: str = "neg", + ignore_index: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -90,6 +99,11 @@ def __init__( self.empty_target_action = empty_target_action + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError("Argument `ignore_index` must be an integer or None.") + + self.ignore_index = ignore_index + self.add_state("indexes", default=[], dist_reduce_fx=None) self.add_state("preds", default=[], dist_reduce_fx=None) self.add_state("target", default=[], dist_reduce_fx=None) @@ -100,7 +114,7 @@ def update(self, preds: Tensor, target: Tensor, indexes: Tensor) -> None: # typ raise ValueError("Argument `indexes` cannot be None") indexes, preds, target = _check_retrieval_inputs( - indexes, preds, target, allow_non_binary_target=self.allow_non_binary_target + indexes, preds, target, allow_non_binary_target=self.allow_non_binary_target, ignore_index=self.ignore_index ) self.indexes.append(indexes) diff --git a/torchmetrics/retrieval/retrieval_ndcg.py b/torchmetrics/retrieval/retrieval_ndcg.py index 1b4bb708036..a0456fea4f5 100644 --- a/torchmetrics/retrieval/retrieval_ndcg.py +++ b/torchmetrics/retrieval/retrieval_ndcg.py @@ -44,6 +44,8 @@ class RetrievalNormalizedDCG(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. k: consider only the top k elements for each query (default: `None`, which considers them all) compute_on_step: Forward only calls ``update()`` and return None if this is set to False. @@ -58,7 +60,11 @@ class RetrievalNormalizedDCG(RetrievalMetric): Raises: ValueError: - If ``k`` parameter is not `None` or an integer larger than 0 + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0. Example: >>> from torchmetrics import RetrievalNormalizedDCG @@ -75,7 +81,8 @@ class RetrievalNormalizedDCG(RetrievalMetric): def __init__( self, empty_target_action: str = "neg", - k: int = None, + ignore_index: Optional[int] = None, + k: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -83,6 +90,7 @@ def __init__( ) -> None: super().__init__( empty_target_action=empty_target_action, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, diff --git a/torchmetrics/retrieval/retrieval_precision.py b/torchmetrics/retrieval/retrieval_precision.py index a753ee4589c..a31cda48c54 100644 --- a/torchmetrics/retrieval/retrieval_precision.py +++ b/torchmetrics/retrieval/retrieval_precision.py @@ -44,6 +44,8 @@ class RetrievalPrecision(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. k: consider only the top k elements for each query (default: `None`, which considers them all) compute_on_step: Forward only calls ``update()`` and return None if this is set to False. 
@@ -58,7 +60,11 @@ class RetrievalPrecision(RetrievalMetric): Raises: ValueError: - If ``k`` parameter is not `None` or an integer larger than 0 + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0. Example: >>> from torchmetrics import RetrievalPrecision @@ -75,7 +81,8 @@ class RetrievalPrecision(RetrievalMetric): def __init__( self, empty_target_action: str = "neg", - k: int = None, + ignore_index: Optional[int] = None, + k: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -83,6 +90,7 @@ def __init__( ) -> None: super().__init__( empty_target_action=empty_target_action, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, diff --git a/torchmetrics/retrieval/retrieval_r_precision.py b/torchmetrics/retrieval/retrieval_r_precision.py index 1d9e0e6834e..d121f8e6baa 100644 --- a/torchmetrics/retrieval/retrieval_r_precision.py +++ b/torchmetrics/retrieval/retrieval_r_precision.py @@ -42,6 +42,8 @@ class RetrievalRPrecision(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. compute_on_step: Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: @@ -53,6 +55,12 @@ class RetrievalRPrecision(RetrievalMetric): Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. + Raises: + ValueError: + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + Example: >>> from torchmetrics import RetrievalRPrecision >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) diff --git a/torchmetrics/retrieval/retrieval_recall.py b/torchmetrics/retrieval/retrieval_recall.py index 586997041e5..2861b086344 100644 --- a/torchmetrics/retrieval/retrieval_recall.py +++ b/torchmetrics/retrieval/retrieval_recall.py @@ -44,6 +44,8 @@ class RetrievalRecall(RetrievalMetric): - ``'skip'``: skip those queries; if all queries are skipped, ``0.0`` is returned - ``'error'``: raise a ``ValueError`` + ignore_index: + Ignore predictions where the target is equal to this number. k: consider only the top k elements for each query (default: `None`, which considers them all) compute_on_step: Forward only calls ``update()`` and return None if this is set to False. @@ -58,7 +60,11 @@ class RetrievalRecall(RetrievalMetric): Raises: ValueError: - If ``k`` parameter is not `None` or an integer larger than 0 + If ``empty_target_action`` is not one of ``error``, ``skip``, ``neg`` or ``pos``. + ValueError: + If ``ignore_index`` is not `None` or an integer. + ValueError: + If ``k`` parameter is not `None` or an integer larger than 0. 
Example: >>> from torchmetrics import RetrievalRecall @@ -75,7 +81,8 @@ class RetrievalRecall(RetrievalMetric): def __init__( self, empty_target_action: str = "neg", - k: int = None, + ignore_index: Optional[int] = None, + k: Optional[int] = None, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -83,6 +90,7 @@ def __init__( ) -> None: super().__init__( empty_target_action=empty_target_action, + ignore_index=ignore_index, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index ce5a39f2a36..d9ecf06c2cd 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -516,6 +516,7 @@ def _check_retrieval_inputs( preds: Tensor, target: Tensor, allow_non_binary_target: bool = False, + ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor, Tensor]: """Check ``indexes``, ``preds`` and ``target`` tensors are of the same shape and of the correct dtype. @@ -523,6 +524,7 @@ def _check_retrieval_inputs( indexes: tensor with queries indexes preds: tensor with scores/logits target: tensor with ground true labels + ignore_index: ignore predictions where targets are equal to this number Raises: ValueError: @@ -537,14 +539,19 @@ def _check_retrieval_inputs( if indexes.shape != preds.shape or preds.shape != target.shape: raise ValueError("`indexes`, `preds` and `target` must be of the same shape") + if indexes.dtype is not torch.long: + raise ValueError("`indexes` must be a tensor of long integers") + + # remove predictions where target is equal to `ignore_index` + if ignore_index is not None: + valid_positions = target != ignore_index + indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions] + if not indexes.numel() or not indexes.size(): raise ValueError( "`indexes`, `preds` and `target` must be non-empty and non-scalar tensors", ) - if indexes.dtype is not torch.long: - raise ValueError("`indexes` must be a tensor of long integers") - preds, target = _check_retrieval_target_and_prediction_types( preds, target, allow_non_binary_target=allow_non_binary_target ) @@ -578,5 +585,7 @@ def _check_retrieval_target_and_prediction_types( if not allow_non_binary_target and (target.max() > 1 or target.min() < 0): raise ValueError("`target` must contain `binary` values") - target = target.float().flatten() if target.is_floating_point() else target.long().flatten() - return preds.float().flatten(), target + target = target.float() if target.is_floating_point() else target.long() + preds = preds.float() + + return preds.flatten(), target.flatten()
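Below is a minimal usage sketch of the new `ignore_index` behavior. It is illustrative only: the tensor values are invented, and it assumes the `(preds, target, indexes)` update/forward signature shown in the diff above. Per the masking added to `_check_retrieval_inputs`, positions where `target == ignore_index` are dropped before the metric is computed, so the two scores below are expected to match.

```python
import torch
from torchmetrics import RetrievalMAP

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
# -100 marks padded / invalid positions that should not influence the score
target = torch.tensor([0, 1, -100, 0, 1, -100, 1])

# with ignore_index set, the metric is expected to filter out the -100 positions internally
rmap = RetrievalMAP(ignore_index=-100)
score = rmap(preds, target, indexes=indexes)

# equivalent manual filtering, mirroring the masking added to _check_retrieval_inputs
valid = target != -100
score_ref = RetrievalMAP()(preds[valid], target[valid], indexes=indexes[valid])
```

The same pattern applies to the other retrieval metrics touched by this diff (`RetrievalMRR`, `RetrievalPrecision`, `RetrievalRecall`, `RetrievalHitRate`, `RetrievalFallOut`, `RetrievalRPrecision`, `RetrievalNormalizedDCG`), since they all inherit the `ignore_index` handling from `RetrievalMetric`.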