From b726b2e71dcc895becae3fb23f3d74559011919b Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 20:58:21 +0900 Subject: [PATCH 01/44] working implementation --- .../functional/clustering/__init__.py | 5 + .../clustering/mutual_info_score.py | 119 ++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 src/torchmetrics/functional/clustering/__init__.py create mode 100644 src/torchmetrics/functional/clustering/mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py new file mode 100644 index 00000000000..322b4856620 --- /dev/null +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -0,0 +1,5 @@ +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score + +__all__ = [ + "mutual_info_score" +] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py new file mode 100644 index 00000000000..c5ecc323552 --- /dev/null +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -0,0 +1,119 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + +from typing import Optional, Tuple +from torch import Tensor, tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _mutual_info_score_check(preds, target) -> bool: + """Check shape of input tensors.""" + # TODO: check if data are disjoint subsets + return _check_same_shape(preds, target) + + +def _calculate_contingency_matrix( + preds: Tensor, + target: Tensor, + eps: Optional[float] = 1e-16, + sparse: bool = False +) -> Tensor: + """Calculate contingency matrix. + + Args: + preds: predicted labels + target: ground truth labels + sparse: If True, returns contingency matrix as a sparse matrix. + + Returns: + contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ + if eps is not None and sparse is True: + raise ValueError('Cannot specify `eps` and return sparse tensor.') + + preds_classes, preds_idx = torch.unique(preds, return_inverse=True) + target_classes, target_idx = torch.unique(target, return_inverse=True) + + n_classes_preds = preds_classes.size(0) + n_classes_target = target_classes.size(0) + + contingency = torch.sparse_coo_tensor( + torch.stack((target_idx, preds_idx)), + torch.ones(target_idx.size(0)), + (n_classes_target, n_classes_preds) + ) + + if not sparse: + contingency = contingency.to_dense() + if eps: + contingency = contingency + eps + + return contingency + + +def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: + """Update and return variables required to compute the mutual information score. + + Args: + preds: predicted class labels + target: ground truth class labels + + Returns: + contingency: contingency matrix + """ + _mutual_info_score_check(preds, target) + return _calculate_contingency_matrix(preds, target) + + +def _mutual_info_score_compute(contingency: Tensor) -> Tensor: + """Compute the mutual information score based on the contingency matrix. 
+ + Args: + contingency: contingency matrix + + Returns: + mutual_info: mutual information score + """ + N = contingency.sum() + U = contingency.sum(dim=1) + V = contingency.sum(dim=0) + + # Check if preds or target labels only have one cluster + if U.size() == 1 or V.size() == 1: + return tensor(0.0) + + log_outer = torch.log(U).reshape(-1, 1) + torch.log(V) + mutual_info = contingency / N * (torch.log(N) + torch.log(contingency) - log_outer) + return mutual_info.sum() + + +def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute mutual information between two clusterings. + + Args: + preds: predicted classes + target: ground truth classes + + Example: + >>> from torchmetrics.functional.clustering import mutual_info_score + >>> target = torch.tensor([0, 3, 2, 2, 1]) + >>> preds = torch.tensor([1, 3, 2, 0, 1]) + >>> mutual_info_score(preds, target) + tensor([1.05492]) + """ + _mutual_info_score_check(preds, target) + contingency = _mutual_info_score_update(preds, target) + return _mutual_info_score_compute(contingency) From a065ef1338e919abf26d6b46146b63a5c8f2c48f Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 22:33:28 +0900 Subject: [PATCH 02/44] passing functional and basic error tests --- .../clustering/mutual_info_score.py | 22 +++-- tests/unittests/clustering/__init__.py | 0 .../clustering/test_mutual_info_score.py | 88 +++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/clustering/__init__.py create mode 100644 tests/unittests/clustering/test_mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index c5ecc323552..0402eb3382c 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -19,10 +19,14 @@ from torchmetrics.utilities.checks import _check_same_shape -def 
_mutual_info_score_check(preds, target) -> bool: +def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" - # TODO: check if data are disjoint subsets - return _check_same_shape(preds, target) + _check_same_shape(preds, target) + if torch.is_floating_point(preds) or torch.is_floating_point(target): + raise ValueError( + f"Expected discrete values but received {preds.dtype} for" + f"predictions and {target.dtype} for target labels instead." + ) def _calculate_contingency_matrix( @@ -40,6 +44,7 @@ def _calculate_contingency_matrix( Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ if eps is not None and sparse is True: raise ValueError('Cannot specify `eps` and return sparse tensor.') @@ -64,7 +69,10 @@ def _calculate_contingency_matrix( return contingency -def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: +def _mutual_info_score_update( + preds: Tensor, + target: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. 
Args: @@ -73,8 +81,9 @@ def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: Returns: contingency: contingency matrix + """ - _mutual_info_score_check(preds, target) + check_cluster_labels(preds, target) return _calculate_contingency_matrix(preds, target) @@ -86,6 +95,7 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: Returns: mutual_info: mutual information score + """ N = contingency.sum() U = contingency.sum(dim=1) @@ -113,7 +123,7 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: >>> preds = torch.tensor([1, 3, 2, 0, 1]) >>> mutual_info_score(preds, target) tensor([1.05492]) + """ - _mutual_info_score_check(preds, target) contingency = _mutual_info_score_update(preds, target) return _mutual_info_score_compute(contingency) diff --git a/tests/unittests/clustering/__init__.py b/tests/unittests/clustering/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py new file mode 100644 index 00000000000..a44fead5751 --- /dev/null +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -0,0 +1,88 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import mutual_info_score as scipy_mutual_info_score +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.clustering.mutual_info_score import MutualInfoScore + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_single_target_inputs1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_inputs2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), + target=torch.rand((NUM_BATCHES, BATCH_SIZE)), +) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestMutualInfoScore(MetricTester): + """Test class for `MutualInfoScore` metric.""" + + atol = 1e-3 + + @pytest.mark.parametrize("compute_on_cpu", [True, False]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): + """Test class implementation of metric.""" + metric_args = {"num_classes": NUM_CLASSES} + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MutualInfoScore, + reference_metric=scipy_mutual_info_score, + metric_args=metric_args + ) + + def test_mutual_info_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + 
metric_functional=mutual_info_score, + reference_metric=scipy_mutual_info_score + ) + + +def test_mutual_info_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs + with pytest.raises(ValueError, match=r"Expected discrete *"): + mutual_info_score(preds, target) From f355a3be77ea9b1bff92d622fc8ea87d8a1c7b06 Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 20:58:21 +0900 Subject: [PATCH 03/44] working implementation --- .../functional/clustering/__init__.py | 5 + .../clustering/mutual_info_score.py | 119 ++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 src/torchmetrics/functional/clustering/__init__.py create mode 100644 src/torchmetrics/functional/clustering/mutual_info_score.py diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py new file mode 100644 index 00000000000..322b4856620 --- /dev/null +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -0,0 +1,5 @@ +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score + +__all__ = [ + "mutual_info_score" +] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py new file mode 100644 index 00000000000..c5ecc323552 --- /dev/null +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -0,0 +1,119 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from typing import Optional, Tuple +from torch import Tensor, tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _mutual_info_score_check(preds, target) -> bool: + """Check shape of input tensors.""" + # TODO: check if data are disjoint subsets + return _check_same_shape(preds, target) + + +def _calculate_contingency_matrix( + preds: Tensor, + target: Tensor, + eps: Optional[float] = 1e-16, + sparse: bool = False +) -> Tensor: + """Calculate contingency matrix. + + Args: + preds: predicted labels + target: ground truth labels + sparse: If True, returns contingency matrix as a sparse matrix. + + Returns: + contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ + if eps is not None and sparse is True: + raise ValueError('Cannot specify `eps` and return sparse tensor.') + + preds_classes, preds_idx = torch.unique(preds, return_inverse=True) + target_classes, target_idx = torch.unique(target, return_inverse=True) + + n_classes_preds = preds_classes.size(0) + n_classes_target = target_classes.size(0) + + contingency = torch.sparse_coo_tensor( + torch.stack((target_idx, preds_idx)), + torch.ones(target_idx.size(0)), + (n_classes_target, n_classes_preds) + ) + + if not sparse: + contingency = contingency.to_dense() + if eps: + contingency = contingency + eps + + return contingency + + +def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: + """Update and return variables required to compute the mutual information score. 
+ + Args: + preds: predicted class labels + target: ground truth class labels + + Returns: + contingency: contingency matrix + """ + _mutual_info_score_check(preds, target) + return _calculate_contingency_matrix(preds, target) + + +def _mutual_info_score_compute(contingency: Tensor) -> Tensor: + """Compute the mutual information score based on the contingency matrix. + + Args: + contingency: contingency matrix + + Returns: + mutual_info: mutual information score + """ + N = contingency.sum() + U = contingency.sum(dim=1) + V = contingency.sum(dim=0) + + # Check if preds or target labels only have one cluster + if U.size() == 1 or V.size() == 1: + return tensor(0.0) + + log_outer = torch.log(U).reshape(-1, 1) + torch.log(V) + mutual_info = contingency / N * (torch.log(N) + torch.log(contingency) - log_outer) + return mutual_info.sum() + + +def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute mutual information between two clusterings. + + Args: + preds: predicted classes + target: ground truth classes + + Example: + >>> from torchmetrics.functional.clustering import mutual_info_score + >>> target = torch.tensor([0, 3, 2, 2, 1]) + >>> preds = torch.tensor([1, 3, 2, 0, 1]) + >>> mutual_info_score(preds, target) + tensor([1.05492]) + """ + _mutual_info_score_check(preds, target) + contingency = _mutual_info_score_update(preds, target) + return _mutual_info_score_compute(contingency) From e6862da1b1b793e501440354597aadaf26735d39 Mon Sep 17 00:00:00 2001 From: Shion Date: Sat, 19 Aug 2023 22:33:28 +0900 Subject: [PATCH 04/44] passing functional and basic error tests --- .../clustering/mutual_info_score.py | 22 +++-- tests/unittests/clustering/__init__.py | 0 .../clustering/test_mutual_info_score.py | 88 +++++++++++++++++++ 3 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/clustering/__init__.py create mode 100644 tests/unittests/clustering/test_mutual_info_score.py diff --git 
a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index c5ecc323552..0402eb3382c 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -19,10 +19,14 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mutual_info_score_check(preds, target) -> bool: +def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" - # TODO: check if data are disjoint subsets - return _check_same_shape(preds, target) + _check_same_shape(preds, target) + if torch.is_floating_point(preds) or torch.is_floating_point(target): + raise ValueError( + f"Expected discrete values but received {preds.dtype} for" + f"predictions and {target.dtype} for target labels instead." + ) def _calculate_contingency_matrix( @@ -40,6 +44,7 @@ def _calculate_contingency_matrix( Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + """ if eps is not None and sparse is True: raise ValueError('Cannot specify `eps` and return sparse tensor.') @@ -64,7 +69,10 @@ def _calculate_contingency_matrix( return contingency -def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: +def _mutual_info_score_update( + preds: Tensor, + target: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. 
Args: @@ -73,8 +81,9 @@ def _mutual_info_score_update(preds, target) -> Tuple[Tensor, Tensor, Tensor]: Returns: contingency: contingency matrix + """ - _mutual_info_score_check(preds, target) + check_cluster_labels(preds, target) return _calculate_contingency_matrix(preds, target) @@ -86,6 +95,7 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: Returns: mutual_info: mutual information score + """ N = contingency.sum() U = contingency.sum(dim=1) @@ -113,7 +123,7 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: >>> preds = torch.tensor([1, 3, 2, 0, 1]) >>> mutual_info_score(preds, target) tensor([1.05492]) + """ - _mutual_info_score_check(preds, target) contingency = _mutual_info_score_update(preds, target) return _mutual_info_score_compute(contingency) diff --git a/tests/unittests/clustering/__init__.py b/tests/unittests/clustering/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py new file mode 100644 index 00000000000..a44fead5751 --- /dev/null +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -0,0 +1,88 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import mutual_info_score as scipy_mutual_info_score +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.clustering.mutual_info_score import MutualInfoScore + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_single_target_inputs1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_inputs2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), + target=torch.rand((NUM_BATCHES, BATCH_SIZE)), +) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestMutualInfoScore(MetricTester): + """Test class for `MutualInfoScore` metric.""" + + atol = 1e-3 + + @pytest.mark.parametrize("compute_on_cpu", [True, False]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): + """Test class implementation of metric.""" + metric_args = {"num_classes": NUM_CLASSES} + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MutualInfoScore, + reference_metric=scipy_mutual_info_score, + metric_args=metric_args + ) + + def test_mutual_info_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + 
metric_functional=mutual_info_score, + reference_metric=scipy_mutual_info_score + ) + + +def test_mutual_info_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs + with pytest.raises(ValueError, match=r"Expected discrete *"): + mutual_info_score(preds, target) From fbfae57e098aab61f8efef95fcc5570854f842d8 Mon Sep 17 00:00:00 2001 From: Shion Date: Mon, 21 Aug 2023 23:57:55 +0900 Subject: [PATCH 05/44] clean up naming and imports --- .../functional/clustering/mutual_info_score.py | 12 +++++++----- .../unittests/clustering/test_mutual_info_score.py | 14 +++++++------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 0402eb3382c..badfd3aaa02 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -19,12 +19,13 @@ from torchmetrics.utilities.checks import _check_same_shape -def check_cluster_labels(preds: Tensor, target: Tensor) -> None: +def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" _check_same_shape(preds, target) - if torch.is_floating_point(preds) or torch.is_floating_point(target): + if torch.is_floating_point(preds) or torch.is_complex(preds) or \ + torch.is_floating_point(target) or torch.is_complex(target): raise ValueError( - f"Expected discrete values but received {preds.dtype} for" + f"Expected real, discrete values but received {preds.dtype} for" f"predictions and {target.dtype} for target labels instead." ) @@ -71,7 +72,8 @@ def _calculate_contingency_matrix( def _mutual_info_score_update( preds: Tensor, - target: Tensor + target: Tensor, + # num_classes: int ) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. 
@@ -83,7 +85,7 @@ def _mutual_info_score_update( contingency: contingency matrix """ - check_cluster_labels(preds, target) + _check_cluster_labels(preds, target) return _calculate_contingency_matrix(preds, target) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index a44fead5751..4ec74ad28c2 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -16,7 +16,7 @@ import pytest import torch -from sklearn.metrics import mutual_info_score as scipy_mutual_info_score +from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.clustering.mutual_info_score import MutualInfoScore @@ -55,20 +55,20 @@ class TestMutualInfoScore(MetricTester): """Test class for `MutualInfoScore` metric.""" - atol = 1e-3 + atol = 1e-5 @pytest.mark.parametrize("compute_on_cpu", [True, False]) @pytest.mark.parametrize("ddp", [True, False]) def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): """Test class implementation of metric.""" - metric_args = {"num_classes": NUM_CLASSES} + # metric_args = {"num_classes": NUM_CLASSES} self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=MutualInfoScore, - reference_metric=scipy_mutual_info_score, - metric_args=metric_args + reference_metric=sklearn_mutual_info_score, + # metric_args=metric_args ) def test_mutual_info_score_functional(self, preds, target): @@ -77,12 +77,12 @@ def test_mutual_info_score_functional(self, preds, target): preds=preds, target=target, metric_functional=mutual_info_score, - reference_metric=scipy_mutual_info_score + reference_metric=sklearn_mutual_info_score, ) def test_mutual_info_score_functional_raises_invalid_task(): """Check that metric rejects continuous-valued inputs.""" preds, target = _float_inputs - with 
pytest.raises(ValueError, match=r"Expected discrete *"): + with pytest.raises(ValueError, match=r"Expected *"): mutual_info_score(preds, target) From f72183d6471cb1890ee87e8643ef15e68c72e643 Mon Sep 17 00:00:00 2001 From: Shion Date: Mon, 21 Aug 2023 23:58:25 +0900 Subject: [PATCH 06/44] push metric class (broken but to allow review) --- src/torchmetrics/clustering/__init__.py | 19 +++ .../clustering/mutual_info_score.py | 130 ++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 src/torchmetrics/clustering/__init__.py create mode 100644 src/torchmetrics/clustering/mutual_info_score.py diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py new file mode 100644 index 00000000000..e86cc406cb1 --- /dev/null +++ b/src/torchmetrics/clustering/__init__.py @@ -0,0 +1,19 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.clustering.mutual_info_score import MutualInfoScore + + +__all__ = [ + "MutualInfoScore", +] diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py new file mode 100644 index 00000000000..a1c781e2cae --- /dev/null +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -0,0 +1,130 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from typing import Any, Optional, List, Sequence, Union +from torch import Tensor + +from torchmetrics.functional.clustering.mutual_info_score import ( + _mutual_info_score_compute, + _mutual_info_score_update +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MutualInfoScore.plot"] + + +class MutualInfoScore(Metric): + r"""Compute `Mutual Information Score`_. + + .. math:: + MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + + Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, + :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and + :math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`. + + The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields + the same mutual information score. + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` + - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + + Example: + >>> from torchmetrics.clustering import MutualInfoScore + >>> target = torch.tensor([]) + >>> preds = torch.tensor([]) + >>> mi_score = MutualInfoScore() + >>> mi_score(preds, target) + tensor() + """ + + is_differentiable = True + higher_is_better = None + full_state_update: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 # theoretical upper bound is +inf + preds: List[Tensor] + target: List[Tensor] + contingency: Tensor + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + # self.num_classes = num_classes + # + # self.add_state("contingency", default=torch.zeros(self.num_classes), dist_reduce_fx=None) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.contingency = _mutual_info_score_update(preds, target) + + def compute(self) -> Tensor: + """Compute mutual information over state.""" + return _mutual_info_score_compute(self.contingency) + + def plot( + self, + val: Union[Tensor, Sequence[Tensor], None] = None, + ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. 
If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore(num_classes=5) + >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore(num_classes=5) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) From 7fe14e0fb6a89435a0e2be6792917a98c26d2058 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:01:58 +0000 Subject: [PATCH 07/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/clustering/__init__.py | 1 - .../clustering/mutual_info_score.py | 16 ++++--------- .../functional/clustering/__init__.py | 4 +--- .../clustering/mutual_info_score.py | 23 +++++++++---------- .../clustering/test_mutual_info_score.py | 2 +- 5 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index e86cc406cb1..baeb8c88d31 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. 
from torchmetrics.clustering.mutual_info_score import MutualInfoScore - __all__ = [ "MutualInfoScore", ] diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index a1c781e2cae..7605d12b08b 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch +from typing import Any, List, Optional, Sequence, Union -from typing import Any, Optional, List, Sequence, Union +import torch from torch import Tensor -from torchmetrics.functional.clustering.mutual_info_score import ( - _mutual_info_score_compute, - _mutual_info_score_update -) +from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -60,6 +57,7 @@ class MutualInfoScore(Metric): >>> mi_score = MutualInfoScore() >>> mi_score(preds, target) tensor() + """ is_differentiable = True @@ -85,11 +83,7 @@ def compute(self) -> Tensor: """Compute mutual information over state.""" return _mutual_info_score_compute(self.contingency) - def plot( - self, - val: Union[Tensor, Sequence[Tensor], None] = None, - ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. 
Args: diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 322b4856620..576acda5108 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -1,5 +1,3 @@ from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score -__all__ = [ - "mutual_info_score" -] +__all__ = ["mutual_info_score"] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index badfd3aaa02..180aa4496a3 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - from typing import Optional, Tuple + +import torch from torch import Tensor, tensor from torchmetrics.utilities.checks import _check_same_shape @@ -22,8 +22,12 @@ def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: """Check shape of input tensors.""" _check_same_shape(preds, target) - if torch.is_floating_point(preds) or torch.is_complex(preds) or \ - torch.is_floating_point(target) or torch.is_complex(target): + if ( + torch.is_floating_point(preds) + or torch.is_complex(preds) + or torch.is_floating_point(target) + or torch.is_complex(target) + ): raise ValueError( f"Expected real, discrete values but received {preds.dtype} for" f"predictions and {target.dtype} for target labels instead." 
@@ -31,10 +35,7 @@ def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: def _calculate_contingency_matrix( - preds: Tensor, - target: Tensor, - eps: Optional[float] = 1e-16, - sparse: bool = False + preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False ) -> Tensor: """Calculate contingency matrix. @@ -48,7 +49,7 @@ def _calculate_contingency_matrix( """ if eps is not None and sparse is True: - raise ValueError('Cannot specify `eps` and return sparse tensor.') + raise ValueError("Cannot specify `eps` and return sparse tensor.") preds_classes, preds_idx = torch.unique(preds, return_inverse=True) target_classes, target_idx = torch.unique(target, return_inverse=True) @@ -57,9 +58,7 @@ def _calculate_contingency_matrix( n_classes_target = target_classes.size(0) contingency = torch.sparse_coo_tensor( - torch.stack((target_idx, preds_idx)), - torch.ones(target_idx.size(0)), - (n_classes_target, n_classes_preds) + torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) ) if not sparse: diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 4ec74ad28c2..d594d0f140a 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -17,8 +17,8 @@ import pytest import torch from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score -from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.clustering.mutual_info_score import MutualInfoScore +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers import seed_all From 808b2785ad651b0fe21b90801499545792597f2b Mon Sep 17 00:00:00 2001 From: Shion Date: Tue, 22 Aug 2023 00:44:26 +0900 Subject: [PATCH 08/44] add docs files --- 
docs/source/clustering/mutual_info_score.rst | 21 ++++++++++++++++++++ docs/source/index.rst | 8 ++++++++ 2 files changed, 29 insertions(+) create mode 100644 docs/source/clustering/mutual_info_score.rst diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst new file mode 100644 index 00000000000..1b7d13519f4 --- /dev/null +++ b/docs/source/clustering/mutual_info_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Mutual Information Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg + :tags: Clustering + +.. include:: ../links.rst + +################### +Mutual Info. Score +################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.MutualInfoScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.mutual_info_score diff --git a/docs/source/index.rst b/docs/source/index.rst index 9da8cf0a51a..af7b6ff798b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -158,6 +158,14 @@ Or directly from conda classification/* +.. toctree:: + :maxdepth: 2 + :name: clustering + :caption: Clustering + :glob: + + clustering/* + .. 
toctree:: :maxdepth: 2 :name: detection From a0308d2800763c47f2acb7a87bfd49a2192af56e Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 22 Aug 2023 14:11:36 +0200 Subject: [PATCH 09/44] releasing 1.1.0 --- CHANGELOG.md | 36 +---------------------------------- src/torchmetrics/__about__.py | 2 +- 2 files changed, 2 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d4c262ac67..356aca21769 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,56 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** -## [UnReleased] - 2023-MM-DD +## [1.1.0] - 2023-08-22 ### Added - Added source aggregated signal-to-distortion ratio (SA-SDR) metric ([#1882](https://github.com/Lightning-AI/torchmetrics/pull/1882) - - - Added `VisualInformationFidelity` to image package ([#1830](https://github.com/Lightning-AI/torchmetrics/pull/1830)) - - - Added `EditDistance` to text package ([#1906](https://github.com/Lightning-AI/torchmetrics/pull/1906)) - - - Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961)) - - - Added support for evaluating `"segm"` and `"bbox"` detection in `MeanAveragePrecision` at the same time ([#1928](https://github.com/Lightning-AI/torchmetrics/pull/1928)) - - - Added `PerceptualPathLength` to image package ([#1939](https://github.com/Lightning-AI/torchmetrics/pull/1939)) - - - Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937)) - - - Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983)) - - - Added warning to `ClipScore` if long captions are detected and truncate ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001)) - - - Added 
`CLIPImageQualityAssessment` to multimodal package ([#1931](https://github.com/Lightning-AI/torchmetrics/pull/1931)) - - - Added new property `metric_state` to all metrics for users to investigate currently stored tensors in memory ([#2006](https://github.com/Lightning-AI/torchmetrics/pull/2006)) -### Changed - -- - - -### Removed - -- - - -### Fixed - -- - ## [1.0.3] - 2023-08-08 diff --git a/src/torchmetrics/__about__.py b/src/torchmetrics/__about__.py index 47ee37bceba..2115111fa44 100644 --- a/src/torchmetrics/__about__.py +++ b/src/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.1.0.dev" +__version__ = "1.1.0" __author__ = "Lightning-AI et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" From 0d3fec9b80dc33b71b2777f0c2b51468ebc00a20 Mon Sep 17 00:00:00 2001 From: Shion Date: Tue, 22 Aug 2023 23:59:20 +0900 Subject: [PATCH 10/44] Create util functions for clustering. Fix metric implementation. --- docs/source/clustering/mutual_info_score.rst | 6 +- docs/source/links.rst | 1 + .../clustering/mutual_info_score.py | 14 ++-- .../functional/clustering/__init__.py | 13 ++++ .../clustering/mutual_info_score.py | 59 +-------------- .../functional/clustering/utils.py | 74 +++++++++++++++++++ 6 files changed, 103 insertions(+), 64 deletions(-) create mode 100644 src/torchmetrics/functional/clustering/utils.py diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst index 1b7d13519f4..39291801ae9 100644 --- a/docs/source/clustering/mutual_info_score.rst +++ b/docs/source/clustering/mutual_info_score.rst @@ -5,9 +5,9 @@ .. include:: ../links.rst -################### -Mutual Info. 
Score -################### +######################## +Mutual Information Score +######################## Module Interface ________________ diff --git a/docs/source/links.rst b/docs/source/links.rst index 4ca837ccd64..7627490c661 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -150,4 +150,5 @@ .. _CIOU: https://arxiv.org/abs/2005.03572 .. _DIOU: https://arxiv.org/abs/1911.08287v1 .. _GIOU: https://arxiv.org/abs/1902.09630 +.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 7605d12b08b..ff999d7bb0b 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -16,8 +16,9 @@ import torch from torch import Tensor -from torchmetrics.functional.clustering.mutual_info_score import _mutual_info_score_compute, _mutual_info_score_update +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -71,17 +72,18 @@ class MutualInfoScore(Metric): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - # self.num_classes = num_classes - # - # self.add_state("contingency", default=torch.zeros(self.num_classes), dist_reduce_fx=None) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - self.contingency = _mutual_info_score_update(preds, target) + self.preds.append(preds) + self.target.append(target) def compute(self) -> Tensor: """Compute 
mutual information over state.""" - return _mutual_info_score_compute(self.contingency) + return mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 576acda5108..c6f46126ca3 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score __all__ = ["mutual_info_score"] diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 180aa4496a3..6a61e2ac40d 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -11,62 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple +from typing import Tuple import torch from torch import Tensor, tensor - -from torchmetrics.utilities.checks import _check_same_shape - - -def _check_cluster_labels(preds: Tensor, target: Tensor) -> None: - """Check shape of input tensors.""" - _check_same_shape(preds, target) - if ( - torch.is_floating_point(preds) - or torch.is_complex(preds) - or torch.is_floating_point(target) - or torch.is_complex(target) - ): - raise ValueError( - f"Expected real, discrete values but received {preds.dtype} for" - f"predictions and {target.dtype} for target labels instead." - ) - - -def _calculate_contingency_matrix( - preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False -) -> Tensor: - """Calculate contingency matrix. - - Args: - preds: predicted labels - target: ground truth labels - sparse: If True, returns contingency matrix as a sparse matrix. - - Returns: - contingency: contingency matrix of shape (n_classes_target, n_classes_preds) - - """ - if eps is not None and sparse is True: - raise ValueError("Cannot specify `eps` and return sparse tensor.") - - preds_classes, preds_idx = torch.unique(preds, return_inverse=True) - target_classes, target_idx = torch.unique(target, return_inverse=True) - - n_classes_preds = preds_classes.size(0) - n_classes_target = target_classes.size(0) - - contingency = torch.sparse_coo_tensor( - torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) - ) - - if not sparse: - contingency = contingency.to_dense() - if eps: - contingency = contingency + eps - - return contingency +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels def _mutual_info_score_update( @@ -84,8 +33,8 @@ def _mutual_info_score_update( contingency: contingency matrix """ - _check_cluster_labels(preds, target) - return _calculate_contingency_matrix(preds, target) + check_cluster_labels(preds, target) + return 
calculate_contingency_matrix(preds, target) def _mutual_info_score_compute(contingency: Tensor) -> Tensor: diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py new file mode 100644 index 00000000000..334584b8ff4 --- /dev/null +++ b/src/torchmetrics/functional/clustering/utils.py @@ -0,0 +1,74 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape +from typing import Optional + + +def calculate_contingency_matrix( + preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False +) -> Tensor: + """Calculate contingency matrix. + + Args: + preds: predicted labels + target: ground truth labels + sparse: If True, returns contingency matrix as a sparse matrix. 
+ + Returns: + contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + + """ + if eps is not None and sparse is True: + raise ValueError("Cannot specify `eps` and return sparse tensor.") + + preds_classes, preds_idx = torch.unique(preds, return_inverse=True) + target_classes, target_idx = torch.unique(target, return_inverse=True) + + n_classes_preds = preds_classes.size(0) + n_classes_target = target_classes.size(0) + + contingency = torch.sparse_coo_tensor( + torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) + ) + + if not sparse: + contingency = contingency.to_dense() + if eps: + contingency = contingency + eps + + return contingency + + +def check_cluster_labels(preds: Tensor, target: Tensor) -> None: + """Check shape of input tensors and if they are real, discrete tensors. + + Args: + preds: predicted labels + target: ground truth labels + + """ + _check_same_shape(preds, target) + if ( + torch.is_floating_point(preds) + or torch.is_complex(preds) + or torch.is_floating_point(target) + or torch.is_complex(target) + ): + raise ValueError( + f"Expected real, discrete values but received {preds.dtype} for" + f"predictions and {target.dtype} for target labels instead." 
+ ) From 7dad1f97f1e2927e923b38c3b277f3fa282b7f20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Aug 2023 15:00:37 +0000 Subject: [PATCH 11/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/clustering/mutual_info_score.py | 1 + src/torchmetrics/functional/clustering/utils.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 6a61e2ac40d..ac2ed03c41d 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -15,6 +15,7 @@ import torch from torch import Tensor, tensor + from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 334584b8ff4..a5460d22e58 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +from typing import Optional +import torch from torch import Tensor + from torchmetrics.utilities.checks import _check_same_shape -from typing import Optional def calculate_contingency_matrix( From c36d8a0e7c419fdf1d465117ced9c30c9a3ec95d Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 00:15:20 +0900 Subject: [PATCH 12/44] Fix ruff-related errors --- src/torchmetrics/clustering/mutual_info_score.py | 4 ++-- .../functional/clustering/mutual_info_score.py | 14 +++++++------- src/torchmetrics/functional/clustering/utils.py | 9 ++++++--- .../unittests/clustering/test_mutual_info_score.py | 1 - 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index ff999d7bb0b..4e2a573870f 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, List, Optional, Sequence, Union -import torch from torch import Tensor from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score @@ -30,7 +29,8 @@ class MutualInfoScore(Metric): r"""Compute `Mutual Information Score`_. .. 
math:: - MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} + \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 6a61e2ac40d..05ab1af5605 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -15,13 +15,13 @@ import torch from torch import Tensor, tensor + from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels def _mutual_info_score_update( preds: Tensor, target: Tensor, - # num_classes: int ) -> Tuple[Tensor, Tensor, Tensor]: """Update and return variables required to compute the mutual information score. 
@@ -47,16 +47,16 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: mutual_info: mutual information score """ - N = contingency.sum() - U = contingency.sum(dim=1) - V = contingency.sum(dim=0) + n = contingency.sum() + u = contingency.sum(dim=1) + v = contingency.sum(dim=0) # Check if preds or target labels only have one cluster - if U.size() == 1 or V.size() == 1: + if u.size() == 1 or v.size() == 1: return tensor(0.0) - log_outer = torch.log(U).reshape(-1, 1) + torch.log(V) - mutual_info = contingency / N * (torch.log(N) + torch.log(contingency) - log_outer) + log_outer = torch.log(u).reshape(-1, 1) + torch.log(v) + mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer) return mutual_info.sum() diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 334584b8ff4..a0b35277a81 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch +from typing import Optional +import torch from torch import Tensor + from torchmetrics.utilities.checks import _check_same_shape -from typing import Optional def calculate_contingency_matrix( @@ -26,7 +27,9 @@ def calculate_contingency_matrix( Args: preds: predicted labels target: ground truth labels - sparse: If True, returns contingency matrix as a sparse matrix. + eps: value added to contingency matrix + sparse: If True, returns contingency matrix as a sparse matrix. Else, return as dense matrix. + `eps` must be `None` if `sparse` is `True`. 
Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index d594d0f140a..555539e4a91 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple -from functools import partial import pytest import torch From f677483e9b0d9d9c20850ef9e77d5ba8b9dbd1e4 Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 00:34:24 +0900 Subject: [PATCH 13/44] Fix docstring examples --- .../clustering/mutual_info_score.py | 20 ++++++++++--------- .../clustering/mutual_info_score.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 4e2a573870f..51709716c58 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -52,12 +52,13 @@ class MutualInfoScore(Metric): - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score Example: + >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> target = torch.tensor([]) - >>> preds = torch.tensor([]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> preds = torch.tensor([2, 1, 0, 1, 0]) >>> mi_score = MutualInfoScore() >>> mi_score(preds, target) - tensor() + tensor(0.5004) """ @@ -106,8 +107,9 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore(num_classes=5) - >>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))) + >>> metric = MutualInfoScore() + 
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> metric.compute() >>> fig_, ax_ = metric.plot() .. plot:: @@ -116,11 +118,11 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore(num_classes=5) - >>> values = [ ] + >>> metric = MutualInfoScore() >>> for _ in range(10): - ... values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))) - >>> fig_, ax_ = metric.plot(values) + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))) + >>> metric.compute() + >>> fig_, ax_ = metric.plot() """ return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 05ab1af5605..7cf032b7ebb 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -72,7 +72,7 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: >>> target = torch.tensor([0, 3, 2, 2, 1]) >>> preds = torch.tensor([1, 3, 2, 0, 1]) >>> mutual_info_score(preds, target) - tensor([1.05492]) + tensor(1.0549) """ contingency = _mutual_info_score_update(preds, target) From 0d361d16f5c2ccfc5d3228a69ffd527e02b73d1e Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 00:56:32 +0900 Subject: [PATCH 14/44] Test functional metric for symmetry --- .../unittests/clustering/test_mutual_info_score.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 555539e4a91..ca9fa1fa0b5 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -60,14 +60,12 @@ class 
TestMutualInfoScore(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_mutual_info_score(self, preds, target, compute_on_cpu, ddp): """Test class implementation of metric.""" - # metric_args = {"num_classes": NUM_CLASSES} self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=MutualInfoScore, reference_metric=sklearn_mutual_info_score, - # metric_args=metric_args ) def test_mutual_info_score_functional(self, preds, target): @@ -85,3 +83,15 @@ def test_mutual_info_score_functional_raises_invalid_task(): preds, target = _float_inputs with pytest.raises(ValueError, match=r"Expected *"): mutual_info_score(preds, target) + + +@pytest.mark.parametrize( + ("preds", "target"), + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + ], +) +def test_mutual_info_score_functional_is_symmetric(preds, target): + """Check that the metric funtional is symmetric.""" + for p, t in zip(preds, target): + assert torch.allclose(mutual_info_score(p, t), mutual_info_score(t, p)) From 422ace325f5364c736f4cbec62c24fbdf34ebdd8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 23 Aug 2023 08:16:35 +0200 Subject: [PATCH 15/44] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94a52a9a6fa..baef1be3009 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008) ### Changed From bf05b8bc5aa730d1ba1743c255ea5fd0b8e7390d Mon Sep 17 00:00:00 2001 From: Shion Date: Wed, 23 Aug 2023 22:47:23 +0900 Subject: [PATCH 16/44] Fix type hint error. Additional checks for tensor shapes. 
--- .../functional/clustering/mutual_info_score.py | 7 +------ src/torchmetrics/functional/clustering/utils.py | 2 ++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index 7cf032b7ebb..f81ab46fd96 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -11,18 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple - import torch from torch import Tensor, tensor from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels -def _mutual_info_score_update( - preds: Tensor, - target: Tensor, -) -> Tuple[Tensor, Tensor, Tensor]: +def _mutual_info_score_update(preds: Tensor, target: Tensor) -> Tensor: """Update and return variables required to compute the mutual information score. 
Args: diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index a0b35277a81..ae11a9c4524 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -65,6 +65,8 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None: """ _check_same_shape(preds, target) + if preds.ndim != 1: + raise ValueError(f"Expected arguments to be 1d tensors but got {preds.ndim} and {target.ndim}") if ( torch.is_floating_point(preds) or torch.is_complex(preds) From e9a123319f6089e60ce28bd8a1b07799de9ee106 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 23 Aug 2023 22:48:52 +0900 Subject: [PATCH 17/44] Update src/torchmetrics/clustering/mutual_info_score.py Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/clustering/mutual_info_score.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 51709716c58..fd56000ba6a 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -39,9 +39,6 @@ class MutualInfoScore(Metric): The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same mutual information score. - Args: - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` @@ -51,6 +48,9 @@ class MutualInfoScore(Metric): - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ Example: >>> import torch >>> from torchmetrics.clustering import MutualInfoScore From 9cff8767c1109c27756ffdb60a66beff8964a584 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Wed, 23 Aug 2023 22:49:04 +0900 Subject: [PATCH 18/44] Update src/torchmetrics/clustering/mutual_info_score.py Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/clustering/mutual_info_score.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index fd56000ba6a..123052bd4ac 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -66,7 +66,6 @@ class MutualInfoScore(Metric): higher_is_better = None full_state_update: bool = True plot_lower_bound: float = 0.0 - plot_upper_bound: float = 1.0 # theoretical upper bound is +inf preds: List[Tensor] target: List[Tensor] contingency: Tensor From e4523d41d04783e4c12b42898ee79952464fc486 Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 20:55:35 +0900 Subject: [PATCH 19/44] Test contingency matrix calculation --- .../functional/clustering/utils.py | 30 ++++++- tests/unittests/clustering/test_utils.py | 78 +++++++++++++++++++ 2 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/clustering/test_utils.py diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index ae11a9c4524..7f685e8cbcc 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -20,7 +20,7 @@ def calculate_contingency_matrix( - preds: Tensor, target: Tensor, eps: Optional[float] = 1e-16, sparse: bool = False + preds: Tensor, target: Tensor, eps: Optional[float] = None, sparse: bool = False ) -> Tensor: """Calculate contingency matrix. 
@@ -34,9 +34,21 @@ def calculate_contingency_matrix( Returns: contingency: contingency matrix of shape (n_classes_target, n_classes_preds) + Example: + >>> import torch + >>> from torchmetrics.functional.clustering.utils import calculate_contingency_matrix + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> calculate_contingency_matrix(preds, target, eps=1e-16) + tensor([[1.0000e+00, 1.0000e-16, 1.0000e+00], + [1.0000e+00, 1.0000e+00, 1.0000e-16], + [1.0000e-16, 1.0000e+00, 1.0000e-16]]) + """ if eps is not None and sparse is True: raise ValueError("Cannot specify `eps` and return sparse tensor.") + if preds.ndim != 1 or target.ndim != 1: + raise ValueError(f"Expected 1d `preds` and `target` but got {preds.ndim} and {target.dim}.") preds_classes, preds_idx = torch.unique(preds, return_inverse=True) target_classes, target_idx = torch.unique(target, return_inverse=True) @@ -45,13 +57,23 @@ def calculate_contingency_matrix( n_classes_target = target_classes.size(0) contingency = torch.sparse_coo_tensor( - torch.stack((target_idx, preds_idx)), torch.ones(target_idx.size(0)), (n_classes_target, n_classes_preds) + torch.stack( + ( + target_idx, + preds_idx, + ) + ), + torch.ones(target_idx.size(0)), + ( + n_classes_target, + n_classes_preds, + ), ) if not sparse: contingency = contingency.to_dense() - if eps: - contingency = contingency + eps + if eps: + contingency = contingency + eps return contingency diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py new file mode 100644 index 00000000000..95ee1a6a4a7 --- /dev/null +++ b/tests/unittests/clustering/test_utils.py @@ -0,0 +1,78 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import numpy as np +import pytest +import torch +from sklearn.metrics.cluster import contingency_matrix as sklearn_contingency_matrix +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix + +from unittests import BATCH_SIZE +from unittests.helpers import seed_all + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_sklearn_inputs = Input( + preds=torch.tensor([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]), + target=torch.tensor([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]), +) + +_single_dim_inputs = Input( + preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)), + target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)), +) + +_multi_dim_inputs = Input( + preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)), + target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)), +) + + +@pytest.mark.parametrize( + ("preds", "target"), + [(_sklearn_inputs.preds, _sklearn_inputs.target), (_single_dim_inputs.preds, _single_dim_inputs.target)], +) +class TestContingencyMatrix: + """Test calculation of dense and sparse contingency matrices.""" + + atol = 1e-8 + + @pytest.mark.parametrize("eps", [None, 1e-16]) + def test_contingency_matrix_dense(self, preds, target, eps): + """Check that dense contingency matrices are calculated correctly.""" + tm_c = calculate_contingency_matrix(preds, target, eps) + sklearn_c = sklearn_contingency_matrix(target, preds, eps=eps) + assert np.allclose(tm_c, sklearn_c, atol=self.atol) + + def 
test_contingency_matrix_sparse(self, preds, target): + """Check that sparse contingency matrices are calculated correctly.""" + tm_c = calculate_contingency_matrix(preds, target, sparse=True).to_dense().numpy() + sklearn_c = sklearn_contingency_matrix(target, preds, sparse=True).toarray() + assert np.allclose(tm_c, sklearn_c, atol=self.atol) + + +def test_eps_and_sparse_error(): + """Check that contingency matrix is not calculated if `eps` is nonzero and `sparse` is True.""" + with pytest.raises(ValueError, match="Cannot specify*"): + calculate_contingency_matrix(_single_dim_inputs.preds, _single_dim_inputs.target, eps=1e-16, sparse=True) + + +def test_multidimensional_contingency_error(): + """Check that contingency matrix is not calculated for multidimensional input.""" + with pytest.raises(ValueError, match="Expected 1d*"): + calculate_contingency_matrix(_multi_dim_inputs.preds, _multi_dim_inputs.target) From f1cc3df37e0b943177cbeb1212693776931e6d68 Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 21:57:07 +0900 Subject: [PATCH 20/44] fix mutual info score calculation. all test passing. 
--- .../functional/clustering/mutual_info_score.py | 7 ++++++- tests/unittests/clustering/test_mutual_info_score.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index f81ab46fd96..f7c7cbfa587 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -50,7 +50,12 @@ def _mutual_info_score_compute(contingency: Tensor) -> Tensor: if u.size() == 1 or v.size() == 1: return tensor(0.0) - log_outer = torch.log(u).reshape(-1, 1) + torch.log(v) + # Find indices of nonzero values in U and V + nzu, nzv = torch.nonzero(contingency, as_tuple=True) + contingency = contingency[nzu, nzv] + + # Calculate MI using entries corresponding to nonzero contingency matrix entries + log_outer = torch.log(u[nzu]) + torch.log(v[nzv]) mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer) return mutual_info.sum() diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index ca9fa1fa0b5..5c89b58b169 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -78,6 +78,14 @@ def test_mutual_info_score_functional(self, preds, target): ) +def test_mutual_info_score_functional_single_cluster(): + """Check that metric rejects continuous-valued inputs.""" + tensor_a = torch.randint(NUM_CLASSES, (BATCH_SIZE,)) + tensor_b = torch.zeros(BATCH_SIZE, dtype=torch.int) + assert torch.allclose(mutual_info_score(tensor_a, tensor_b), torch.tensor(0.0)) + assert torch.allclose(mutual_info_score(tensor_b, tensor_a), torch.tensor(0.0)) + + def test_mutual_info_score_functional_raises_invalid_task(): """Check that metric rejects continuous-valued inputs.""" preds, target = _float_inputs From 
f278c5c25bc5713f62bbe282eeb723a91762f3a6 Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 22:00:15 +0900 Subject: [PATCH 21/44] fix plotting docstring --- src/torchmetrics/clustering/mutual_info_score.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 123052bd4ac..f4388f0a528 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -54,8 +54,8 @@ class MutualInfoScore(Metric): Example: >>> import torch >>> from torchmetrics.clustering import MutualInfoScore - >>> target = torch.tensor([0, 2, 1, 1, 0]) >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) >>> mi_score = MutualInfoScore() >>> mi_score(preds, target) tensor(0.5004) @@ -108,8 +108,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> from torchmetrics.clustering import MutualInfoScore >>> metric = MutualInfoScore() >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) - >>> metric.compute() - >>> fig_, ax_ = metric.plot() + >>> fig_, ax_ = metric.plot(metric.compute()) .. plot:: :scale: 75 @@ -120,8 +119,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = MutualInfoScore() >>> for _ in range(10): ... 
metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))) - >>> metric.compute() - >>> fig_, ax_ = metric.plot() + >>> fig_, ax_ = metric.plot(metric.compute()) """ return self._plot(val, ax) From c866355203111b55b3c6f13cdb4215e8a3d4e8ca Mon Sep 17 00:00:00 2001 From: Shion Date: Thu, 24 Aug 2023 22:20:44 +0900 Subject: [PATCH 22/44] add paren --- src/torchmetrics/clustering/mutual_info_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index f4388f0a528..86118daf41c 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -118,7 +118,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> from torchmetrics.clustering import MutualInfoScore >>> metric = MutualInfoScore() >>> for _ in range(10): - ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))) + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) >>> fig_, ax_ = metric.plot(metric.compute()) """ From ca5ff5fc39d488277d650fe6cf4da5566a9bb9b0 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 15:05:47 +0200 Subject: [PATCH 23/44] fix doc import --- docs/source/clustering/mutual_info_score.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/clustering/mutual_info_score.rst b/docs/source/clustering/mutual_info_score.rst index 39291801ae9..e5adf06eaa9 100644 --- a/docs/source/clustering/mutual_info_score.rst +++ b/docs/source/clustering/mutual_info_score.rst @@ -12,10 +12,10 @@ Mutual Information Score Module Interface ________________ -.. autoclass:: torchmetrics.MutualInfoScore +.. autoclass:: torchmetrics.clustering.MutualInfoScore :exclude-members: update, compute Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.mutual_info_score +.. 
autofunction:: torchmetrics.functional.clustering.mutual_info_score From 157e8f833bd02c50fbcdba0ca1d2b4ddb4d9e7ba Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 15:10:54 +0200 Subject: [PATCH 24/44] fix on gpu --- src/torchmetrics/functional/clustering/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 7f685e8cbcc..64dff0377ee 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -63,7 +63,7 @@ def calculate_contingency_matrix( preds_idx, ) ), - torch.ones(target_idx.size(0)), + torch.ones(target_idx.shape[0], dtype=preds_idx.dtype, device=preds_idx.device), ( n_classes_target, n_classes_preds, From 1d6693a168b7f03bc2edfcb91d7f83416228d5f4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 16:55:54 +0200 Subject: [PATCH 25/44] add implementation --- src/torchmetrics/clustering/__init__.py | 2 + src/torchmetrics/clustering/rand_score.py | 125 ++++++++++++++++++ .../functional/clustering/__init__.py | 3 +- .../functional/clustering/rand_score.py | 79 +++++++++++ .../functional/clustering/utils.py | 67 ++++++++++ 5 files changed, 275 insertions(+), 1 deletion(-) create mode 100644 src/torchmetrics/clustering/rand_score.py create mode 100644 src/torchmetrics/functional/clustering/rand_score.py diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index baeb8c88d31..118e59cc451 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from torchmetrics.clustering.mutual_info_score import MutualInfoScore +from torchmetrics.clustering.rand_score import RandScore __all__ = [ "MutualInfoScore", + "RandScore", ] diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py new file mode 100644 index 00000000000..7226d16ed1f --- /dev/null +++ b/src/torchmetrics/clustering/rand_score.py @@ -0,0 +1,125 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.functional.clustering.rand_score import rand_score +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RandScore.plot"] + + +class RandScore(Metric): + r"""Compute `Mutual Information Score`_. + + .. math:: + MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} + \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + + Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, + :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and + :math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`. 
+ + The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields + the same mutual information score. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` + - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.clustering import RandScore + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> mi_score = RandScore() + >>> mi_score(preds, target) + tensor(0.5004) + + """ + + is_differentiable = True + higher_is_better = None + full_state_update: bool = True + plot_lower_bound: float = 0.0 + preds: List[Tensor] + target: List[Tensor] + contingency: Tensor + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Compute mutual information over state.""" + return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. 
+ ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index c6f46126ca3..a2c3c110b1d 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.functional.clustering.rand_score import rand_score -__all__ = ["mutual_info_score"] +__all__ = ["mutual_info_score", "rand_score"] diff --git a/src/torchmetrics/functional/clustering/rand_score.py b/src/torchmetrics/functional/clustering/rand_score.py new file mode 100644 index 00000000000..6c5b96cdf3b --- /dev/null +++ b/src/torchmetrics/functional/clustering/rand_score.py @@ -0,0 +1,79 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels + + +def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor: + """Update and return variables required to compute the rand score. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + contingency: contingency matrix + + """ + check_cluster_labels(preds, target) + return calculate_contingency_matrix(preds, target) + + +def _rand_score_compute(contingency: Tensor) -> Tensor: + """Compute the rand score based on the contingency matrix. + + Args: + contingency: contingency matrix + + Returns: + rand_score: rand score + + """ + n_samples = contingency.sum() + n_c = contingency.sum(dim=1) + n_k = contingency.sum(dim=0) + sum_squared = (contingency**2).sum() + + pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device) + pair_matrix[1, 1] = sum_squared - n_samples + pair_matrix[0, 1] = (contingency * n_k).sum() - sum_squared + pair_matrix[1, 0] = (contingency.T * n_c).sum() - sum_squared + pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared + + numerator = pair_matrix.diagonal().sum() + denominator = pair_matrix.sum() + if numerator == denominator or denominator == 0: + # Special limit cases: no clustering since the data is not split; + # or trivial clustering where each document is assigned a unique + # cluster. These are perfect matches hence return 1.0. 
+ return 1.0 + + return numerator / denominator + + +def rand_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute the Rand score between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + rand_score: rand score + + """ + contingency = _rand_score_update(preds, target) + return _rand_score_compute(contingency) diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 64dff0377ee..048a3a0d106 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -99,3 +99,70 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None: f"Expected real, discrete values but received {preds.dtype} for" f"predictions and {target.dtype} for target labels instead." ) + + +def calcualte_pair_cluster_confusion_matrix( + preds: Optional[Tensor] = None, + target: Optional[Tensor] = None, + contingency: Optional[Tensor] = None, +) -> Tensor: + """Calculates the pair cluster confusion matrix. + + Can either be calculated from predicted cluster labels and target cluster labels or from a pre-computed + contingency matrix. The pair cluster confusion matrix is a 2x2 matrix where that defines the similarity between + two clustering by considering all pairs of samples and counting pairs that are assigned into same or different + clusters in the predicted and target clusterings. + + Note that the matrix is not symmetric. + + Inspired by: + https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cluster.pair_confusion_matrix.html + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + contingency: contingency matrix + + Returns: + A 2x2 tensor containing the pair cluster confusion matrix. + + Raises: + ValueError: + If neither `preds` and `target` nor `contingency` are provided. + ValueError: + If both `preds` and `target` and `contingency` are provided. 
+ + Example: + >>> import torch + >>> from torchmetrics.functional.clustering.utils import calcualte_pair_cluster_confusion_matrix + >>> preds = torch.tensor([0, 0, 1, 1]) + >>> target = torch.tensor([1, 1, 0, 0]) + >>> calcualte_pair_cluster_confusion_matrix(preds, target) + tensor([[8, 0], + [0, 4]]) + >>> preds = torch.tensor([0, 0, 1, 2]) + >>> target = torch.tensor([0, 0, 1, 1]) + >>> calcualte_pair_cluster_confusion_matrix(preds, target) + tensor([[8, 2], + [0, 2]]) + + """ + if preds is None and target is None and contingency is None: + raise ValueError("Must provide either `preds` and `target` or `contingency`.") + if preds is not None and target is not None and contingency is not None: + raise ValueError("Must provide either `preds` and `target` or `contingency`, not both.") + + if preds is not None and target is not None: + contingency = calculate_contingency_matrix(preds, target) + + n_samples = contingency.sum() + n_c = contingency.sum(dim=1) + n_k = contingency.sum(dim=0) + sum_squared = (contingency**2).sum() + + pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device) + pair_matrix[1, 1] = sum_squared - n_samples + pair_matrix[1, 0] = (contingency * n_k).sum() - sum_squared + pair_matrix[0, 1] = (contingency.T * n_c).sum() - sum_squared + pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared + return pair_matrix From 700617bbfc290ea95ce2c010648a4cce46d4fb97 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 16:56:09 +0200 Subject: [PATCH 26/44] add docs --- docs/source/clustering/rand_score.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 docs/source/clustering/rand_score.rst diff --git a/docs/source/clustering/rand_score.rst b/docs/source/clustering/rand_score.rst new file mode 100644 index 00000000000..62650c2d454 --- /dev/null +++ b/docs/source/clustering/rand_score.rst @@ -0,0 +1,21 @@ +.. 
customcarditem:: + :header: Rand Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg + :tags: Clustering + +.. include:: ../links.rst + +########## +Rand Score +########## + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.RandScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.rand_score From 74dd965aa98a44d21947cbaf3f9b4866ddd4526f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 17:03:57 +0200 Subject: [PATCH 27/44] fix text --- docs/source/links.rst | 1 + src/torchmetrics/clustering/rand_score.py | 19 ++-- tests/unittests/clustering/test_rand_score.py | 96 +++++++++++++++++++ tests/unittests/clustering/test_utils.py | 22 ++++- 4 files changed, 126 insertions(+), 12 deletions(-) create mode 100644 tests/unittests/clustering/test_rand_score.py diff --git a/docs/source/links.rst b/docs/source/links.rst index 7627490c661..7e875191a1f 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -152,3 +152,4 @@ .. _GIOU: https://arxiv.org/abs/1902.09630 .. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools +.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075 diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 7226d16ed1f..fd87706864f 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -26,27 +26,24 @@ class RandScore(Metric): - r"""Compute `Mutual Information Score`_. + r"""Compute `Rand Score`_ (alternative know as Rand Index). .. 
math:: - MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} - \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs} - Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, - :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and - :math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`. + The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V` + (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings. - The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields - the same mutual information score. + The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score. As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` + - ``target`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` As output of ``forward`` and ``compute`` the metric returns the following output: - - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py new file mode 100644 index 00000000000..a5a6df11267 --- /dev/null +++ b/tests/unittests/clustering/test_rand_score.py @@ -0,0 +1,96 @@ +# Copyright The Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import pytest +import torch +from sklearn.metrics import rand_score as sklearn_rand_score +from torchmetrics.clustering.rand_score import RandScore +from torchmetrics.functional.clustering.rand_score import rand_score + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_single_target_inputs1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_inputs2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), + target=torch.rand((NUM_BATCHES, BATCH_SIZE)), +) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestRandScore(MetricTester): + """Test class for `RandScore` metric.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_rand_score(self, preds, target, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + 
ddp=ddp, + preds=preds, + target=target, + metric_class=RandScore, + reference_metric=sklearn_rand_score, + ) + + def test_rand_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=rand_score, + reference_metric=sklearn_rand_score, + ) + + +def test_rand_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs + with pytest.raises(ValueError, match=r"Expected *"): + rand_score(preds, target) + + +@pytest.mark.parametrize( + ("preds", "target"), + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + ], +) +def test_rand_score_functional_is_symmetric(preds, target): + """Check that the metric funtional is symmetric.""" + for p, t in zip(preds, target): + assert torch.allclose(rand_score(p, t), rand_score(t, p)) diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py index 95ee1a6a4a7..571ee5614ee 100644 --- a/tests/unittests/clustering/test_utils.py +++ b/tests/unittests/clustering/test_utils.py @@ -17,7 +17,11 @@ import pytest import torch from sklearn.metrics.cluster import contingency_matrix as sklearn_contingency_matrix -from torchmetrics.functional.clustering.utils import calculate_contingency_matrix +from sklearn.metrics.cluster import pair_confusion_matrix as sklearn_pair_confusion_matrix +from torchmetrics.functional.clustering.utils import ( + calcualte_pair_cluster_confusion_matrix, + calculate_contingency_matrix, +) from unittests import BATCH_SIZE from unittests.helpers import seed_all @@ -76,3 +80,19 @@ def test_multidimensional_contingency_error(): """Check that contingency matrix is not calculated for multidimensional input.""" with pytest.raises(ValueError, match="Expected 1d*"): calculate_contingency_matrix(_multi_dim_inputs.preds, _multi_dim_inputs.target) + + +@pytest.mark.parametrize( + ("preds", 
"target"), + [(_sklearn_inputs.preds, _sklearn_inputs.target), (_single_dim_inputs.preds, _single_dim_inputs.target)], +) +class TestPairClusterConfusionMatrix: + """Test that implementation matches sklearns.""" + + atol = 1e-8 + + def test_pair_cluster_confusion_matrix(self, preds, target): + """Check that pair cluster confusion matrix is calculated correctly.""" + tm_res = calcualte_pair_cluster_confusion_matrix(preds, target) + sklearn_res = sklearn_pair_confusion_matrix(preds, target) + assert np.allclose(tm_res, sklearn_res, atol=self.atol) From 3f5536d8b7fae003f71ed2477ae6e51dad677d56 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 19:55:56 +0200 Subject: [PATCH 28/44] fixes to doctest + docstrings --- .../clustering/mutual_info_score.py | 4 +-- src/torchmetrics/clustering/rand_score.py | 20 +++++++------- .../clustering/mutual_info_score.py | 4 +-- .../functional/clustering/rand_score.py | 27 ++++++++++--------- tests/unittests/utilities/test_plot.py | 3 +++ 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 86118daf41c..e943c2aec27 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -41,8 +41,8 @@ class MutualInfoScore(Metric): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels As output of ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/clustering/rand_score.py 
b/src/torchmetrics/clustering/rand_score.py index fd87706864f..49341b00464 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -38,12 +38,12 @@ class RandScore(Metric): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels As output of ``forward`` and ``compute`` the metric returns the following output: - - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score + - ``rand_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
@@ -53,9 +53,9 @@ class RandScore(Metric): >>> from torchmetrics.clustering import RandScore >>> preds = torch.tensor([2, 1, 0, 1, 0]) >>> target = torch.tensor([0, 2, 1, 1, 0]) - >>> mi_score = RandScore() - >>> mi_score(preds, target) - tensor(0.5004) + >>> metric = RandScore() + >>> metric(preds, target) + tensor(0.6000) """ @@ -102,8 +102,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore() + >>> from torchmetrics.clustering import RandScore + >>> metric = RandScore() >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) >>> fig_, ax_ = metric.plot(metric.compute()) @@ -112,8 +112,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore() + >>> from torchmetrics.clustering import RandScore + >>> metric = RandScore() >>> for _ in range(10): ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) >>> fig_, ax_ = metric.plot(metric.compute()) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index f7c7cbfa587..a729726436e 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -64,8 +64,8 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: """Compute mutual information between two clusterings. 
Args: - preds: predicted classes - target: ground truth classes + preds: predicted cluster labels + target: ground truth cluster labels Example: >>> from torchmetrics.functional.clustering import mutual_info_score diff --git a/src/torchmetrics/functional/clustering/rand_score.py b/src/torchmetrics/functional/clustering/rand_score.py index 6c5b96cdf3b..98e7f0258ea 100644 --- a/src/torchmetrics/functional/clustering/rand_score.py +++ b/src/torchmetrics/functional/clustering/rand_score.py @@ -14,7 +14,11 @@ import torch from torch import Tensor -from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels +from torchmetrics.functional.clustering.utils import ( + calcualte_pair_cluster_confusion_matrix, + calculate_contingency_matrix, + check_cluster_labels, +) def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor: @@ -42,16 +46,7 @@ def _rand_score_compute(contingency: Tensor) -> Tensor: rand_score: rand score """ - n_samples = contingency.sum() - n_c = contingency.sum(dim=1) - n_k = contingency.sum(dim=0) - sum_squared = (contingency**2).sum() - - pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device) - pair_matrix[1, 1] = sum_squared - n_samples - pair_matrix[0, 1] = (contingency * n_k).sum() - sum_squared - pair_matrix[1, 0] = (contingency.T * n_c).sum() - sum_squared - pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared + pair_matrix = calcualte_pair_cluster_confusion_matrix(contingency=contingency) numerator = pair_matrix.diagonal().sum() denominator = pair_matrix.sum() @@ -59,7 +54,7 @@ def _rand_score_compute(contingency: Tensor) -> Tensor: # Special limit cases: no clustering since the data is not split; # or trivial clustering where each document is assigned a unique # cluster. These are perfect matches hence return 1.0. 
- return 1.0 + return torch.ones_like(numerator, dtype=torch.float32) return numerator / denominator @@ -74,6 +69,14 @@ def rand_score(preds: Tensor, target: Tensor) -> Tensor: Returns: rand_score: rand score + Example: + >>> from torchmetrics.functional.clustering import rand_score + >>> import torch + >>> rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0])) + tensor(1.) + >>> rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) + tensor(0.8333) + """ contingency = _rand_score_update(preds, target) return _rand_score_compute(contingency) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index b47b06da7f8..f5c7d8d4562 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -91,6 +91,7 @@ MultilabelROC, MultilabelSpecificity, ) +from torchmetrics.clustering import MutualInfoScore, RandScore from torchmetrics.detection import PanopticQuality from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio @@ -614,6 +615,8 @@ id="squad", ), pytest.param(TranslationEditRate, _text_input_3, _text_input_4, id="translation edit rate"), + pytest.param(MutualInfoScore, _nominal_input, _nominal_input, id="mutual info score"), + pytest.param(RandScore, _nominal_input, _nominal_input, id="rand score"), ], ) @pytest.mark.parametrize("num_vals", [1, 3]) From 747a4adb63501011da8675b32455309bc14e74c4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 25 Aug 2023 19:58:52 +0200 Subject: [PATCH 29/44] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b23ce58d355..25743bfb0dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MutualInformationScore` metric to cluster package 
([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008) +- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025) + + ### Changed - From 42b10d0ef72b33638cae696864e6de92fc634f0f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 20:22:35 +0200 Subject: [PATCH 30/44] fix --- src/torchmetrics/functional/clustering/rand_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/clustering/rand_score.py b/src/torchmetrics/functional/clustering/rand_score.py index 98e7f0258ea..2e6848f46d5 100644 --- a/src/torchmetrics/functional/clustering/rand_score.py +++ b/src/torchmetrics/functional/clustering/rand_score.py @@ -67,7 +67,7 @@ def rand_score(preds: Tensor, target: Tensor) -> Tensor: target: ground truth cluster labels Returns: - rand_score: rand score + scalar tensor with the rand score Example: >>> from torchmetrics.functional.clustering import rand_score From 1a6d39fd1961cec96276d55d5d3e3f92fc97e707 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 26 Aug 2023 10:09:33 +0200 Subject: [PATCH 31/44] Update tests/unittests/clustering/test_rand_score.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- tests/unittests/clustering/test_rand_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index a5a6df11267..9ff11b1d6c5 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -87,7 +87,7 @@ def test_rand_score_functional_raises_invalid_task(): @pytest.mark.parametrize( ("preds", "target"), [ - (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs1.preds, _single_target_inputs1.target) ], ) def test_rand_score_functional_is_symmetric(preds, target): From 25eb88de0485578622849417bf588ef04afb2954 Mon Sep 17 
00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Aug 2023 08:10:07 +0000 Subject: [PATCH 32/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/clustering/test_rand_score.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 9ff11b1d6c5..d00fd421d34 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -86,9 +86,7 @@ def test_rand_score_functional_raises_invalid_task(): @pytest.mark.parametrize( ("preds", "target"), - [ - (_single_target_inputs1.preds, _single_target_inputs1.target) - ], + [(_single_target_inputs1.preds, _single_target_inputs1.target)], ) def test_rand_score_functional_is_symmetric(preds, target): """Check that the metric funtional is symmetric.""" From 4bf95953469cd373b1da5cb12a3ac06f3d57b56d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 16:55:54 +0200 Subject: [PATCH 33/44] add implementation --- src/torchmetrics/clustering/__init__.py | 2 + src/torchmetrics/clustering/rand_score.py | 125 ++++++++++++++++++ .../functional/clustering/__init__.py | 3 +- .../functional/clustering/rand_score.py | 79 +++++++++++ .../functional/clustering/utils.py | 67 ++++++++++ 5 files changed, 275 insertions(+), 1 deletion(-) create mode 100644 src/torchmetrics/clustering/rand_score.py create mode 100644 src/torchmetrics/functional/clustering/rand_score.py diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index baeb8c88d31..118e59cc451 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from torchmetrics.clustering.mutual_info_score import MutualInfoScore +from torchmetrics.clustering.rand_score import RandScore __all__ = [ "MutualInfoScore", + "RandScore", ] diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py new file mode 100644 index 00000000000..7226d16ed1f --- /dev/null +++ b/src/torchmetrics/clustering/rand_score.py @@ -0,0 +1,125 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.functional.clustering.rand_score import rand_score +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RandScore.plot"] + + +class RandScore(Metric): + r"""Compute `Mutual Information Score`_. + + .. math:: + MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} + \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + + Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, + :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and + :math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`. 
+ + The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields + the same mutual information score. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` + - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.clustering import RandScore + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> mi_score = RandScore() + >>> mi_score(preds, target) + tensor(0.5004) + + """ + + is_differentiable = True + higher_is_better = None + full_state_update: bool = True + plot_lower_bound: float = 0.0 + preds: List[Tensor] + target: List[Tensor] + contingency: Tensor + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Compute mutual information over state.""" + return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. 
+ ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import MutualInfoScore + >>> metric = MutualInfoScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index c6f46126ca3..a2c3c110b1d 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.functional.clustering.rand_score import rand_score -__all__ = ["mutual_info_score"] +__all__ = ["mutual_info_score", "rand_score"] diff --git a/src/torchmetrics/functional/clustering/rand_score.py b/src/torchmetrics/functional/clustering/rand_score.py new file mode 100644 index 00000000000..6c5b96cdf3b --- /dev/null +++ b/src/torchmetrics/functional/clustering/rand_score.py @@ -0,0 +1,79 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels + + +def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor: + """Update and return variables required to compute the rand score. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + contingency: contingency matrix + + """ + check_cluster_labels(preds, target) + return calculate_contingency_matrix(preds, target) + + +def _rand_score_compute(contingency: Tensor) -> Tensor: + """Compute the rand score based on the contingency matrix. + + Args: + contingency: contingency matrix + + Returns: + rand_score: rand score + + """ + n_samples = contingency.sum() + n_c = contingency.sum(dim=1) + n_k = contingency.sum(dim=0) + sum_squared = (contingency**2).sum() + + pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device) + pair_matrix[1, 1] = sum_squared - n_samples + pair_matrix[0, 1] = (contingency * n_k).sum() - sum_squared + pair_matrix[1, 0] = (contingency.T * n_c).sum() - sum_squared + pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared + + numerator = pair_matrix.diagonal().sum() + denominator = pair_matrix.sum() + if numerator == denominator or denominator == 0: + # Special limit cases: no clustering since the data is not split; + # or trivial clustering where each document is assigned a unique + # cluster. These are perfect matches hence return 1.0. 
+ return 1.0 + + return numerator / denominator + + +def rand_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute the Rand score between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + rand_score: rand score + + """ + contingency = _rand_score_update(preds, target) + return _rand_score_compute(contingency) diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 64dff0377ee..048a3a0d106 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -99,3 +99,70 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None: f"Expected real, discrete values but received {preds.dtype} for" f"predictions and {target.dtype} for target labels instead." ) + + +def calcualte_pair_cluster_confusion_matrix( + preds: Optional[Tensor] = None, + target: Optional[Tensor] = None, + contingency: Optional[Tensor] = None, +) -> Tensor: + """Calculates the pair cluster confusion matrix. + + Can either be calculated from predicted cluster labels and target cluster labels or from a pre-computed + contingency matrix. The pair cluster confusion matrix is a 2x2 matrix where that defines the similarity between + two clustering by considering all pairs of samples and counting pairs that are assigned into same or different + clusters in the predicted and target clusterings. + + Note that the matrix is not symmetric. + + Inspired by: + https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cluster.pair_confusion_matrix.html + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + contingency: contingency matrix + + Returns: + A 2x2 tensor containing the pair cluster confusion matrix. + + Raises: + ValueError: + If neither `preds` and `target` nor `contingency` are provided. + ValueError: + If both `preds` and `target` and `contingency` are provided. 
+ + Example: + >>> import torch + >>> from torchmetrics.functional.clustering.utils import calcualte_pair_cluster_confusion_matrix + >>> preds = torch.tensor([0, 0, 1, 1]) + >>> target = torch.tensor([1, 1, 0, 0]) + >>> calcualte_pair_cluster_confusion_matrix(preds, target) + tensor([[8, 0], + [0, 4]]) + >>> preds = torch.tensor([0, 0, 1, 2]) + >>> target = torch.tensor([0, 0, 1, 1]) + >>> calcualte_pair_cluster_confusion_matrix(preds, target) + tensor([[8, 2], + [0, 2]]) + + """ + if preds is None and target is None and contingency is None: + raise ValueError("Must provide either `preds` and `target` or `contingency`.") + if preds is not None and target is not None and contingency is not None: + raise ValueError("Must provide either `preds` and `target` or `contingency`, not both.") + + if preds is not None and target is not None: + contingency = calculate_contingency_matrix(preds, target) + + n_samples = contingency.sum() + n_c = contingency.sum(dim=1) + n_k = contingency.sum(dim=0) + sum_squared = (contingency**2).sum() + + pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device) + pair_matrix[1, 1] = sum_squared - n_samples + pair_matrix[1, 0] = (contingency * n_k).sum() - sum_squared + pair_matrix[0, 1] = (contingency.T * n_c).sum() - sum_squared + pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared + return pair_matrix From 349408f26df3d6ad79125f763036ab6077d99a81 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 16:56:09 +0200 Subject: [PATCH 34/44] add docs --- docs/source/clustering/rand_score.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 docs/source/clustering/rand_score.rst diff --git a/docs/source/clustering/rand_score.rst b/docs/source/clustering/rand_score.rst new file mode 100644 index 00000000000..62650c2d454 --- /dev/null +++ b/docs/source/clustering/rand_score.rst @@ -0,0 +1,21 @@ +.. 
customcarditem:: + :header: Rand Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg + :tags: Clustering + +.. include:: ../links.rst + +########## +Rand Score +########## + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.RandScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.rand_score From d7161140b36d11200ca7a90b71d32c89bbbcadfa Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 17:03:57 +0200 Subject: [PATCH 35/44] fix text --- docs/source/links.rst | 1 + src/torchmetrics/clustering/rand_score.py | 19 ++-- tests/unittests/clustering/test_rand_score.py | 96 +++++++++++++++++++ tests/unittests/clustering/test_utils.py | 22 ++++- 4 files changed, 126 insertions(+), 12 deletions(-) create mode 100644 tests/unittests/clustering/test_rand_score.py diff --git a/docs/source/links.rst b/docs/source/links.rst index 7627490c661..7e875191a1f 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -152,3 +152,4 @@ .. _GIOU: https://arxiv.org/abs/1902.09630 .. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools +.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075 diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 7226d16ed1f..fd87706864f 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -26,27 +26,24 @@ class RandScore(Metric): - r"""Compute `Mutual Information Score`_. + r"""Compute `Rand Score`_ (alternative know as Rand Index). .. 
math:: - MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N} - \log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}} + RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs} - Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, - :math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and - :math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`. + The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V` + (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings. - The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields - the same mutual information score. + The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score. As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` + - ``target`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` As output of ``forward`` and ``compute`` the metric returns the following output: - - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score + - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py new file mode 100644 index 00000000000..a5a6df11267 --- /dev/null +++ b/tests/unittests/clustering/test_rand_score.py @@ -0,0 +1,96 @@ +# Copyright The Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import pytest +import torch +from sklearn.metrics import rand_score as sklearn_rand_score +from torchmetrics.clustering.rand_score import RandScore +from torchmetrics.functional.clustering.rand_score import rand_score + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +_single_target_inputs1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_inputs2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), + target=torch.rand((NUM_BATCHES, BATCH_SIZE)), +) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ], +) +class TestRandScore(MetricTester): + """Test class for `RandScore` metric.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_rand_score(self, preds, target, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + 
ddp=ddp, + preds=preds, + target=target, + metric_class=RandScore, + reference_metric=sklearn_rand_score, + ) + + def test_rand_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=rand_score, + reference_metric=sklearn_rand_score, + ) + + +def test_rand_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs + with pytest.raises(ValueError, match=r"Expected *"): + rand_score(preds, target) + + +@pytest.mark.parametrize( + ("preds", "target"), + [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + ], +) +def test_rand_score_functional_is_symmetric(preds, target): + """Check that the metric funtional is symmetric.""" + for p, t in zip(preds, target): + assert torch.allclose(rand_score(p, t), rand_score(t, p)) diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py index 95ee1a6a4a7..571ee5614ee 100644 --- a/tests/unittests/clustering/test_utils.py +++ b/tests/unittests/clustering/test_utils.py @@ -17,7 +17,11 @@ import pytest import torch from sklearn.metrics.cluster import contingency_matrix as sklearn_contingency_matrix -from torchmetrics.functional.clustering.utils import calculate_contingency_matrix +from sklearn.metrics.cluster import pair_confusion_matrix as sklearn_pair_confusion_matrix +from torchmetrics.functional.clustering.utils import ( + calcualte_pair_cluster_confusion_matrix, + calculate_contingency_matrix, +) from unittests import BATCH_SIZE from unittests.helpers import seed_all @@ -76,3 +80,19 @@ def test_multidimensional_contingency_error(): """Check that contingency matrix is not calculated for multidimensional input.""" with pytest.raises(ValueError, match="Expected 1d*"): calculate_contingency_matrix(_multi_dim_inputs.preds, _multi_dim_inputs.target) + + +@pytest.mark.parametrize( + ("preds", 
"target"), + [(_sklearn_inputs.preds, _sklearn_inputs.target), (_single_dim_inputs.preds, _single_dim_inputs.target)], +) +class TestPairClusterConfusionMatrix: + """Test that implementation matches sklearns.""" + + atol = 1e-8 + + def test_pair_cluster_confusion_matrix(self, preds, target): + """Check that pair cluster confusion matrix is calculated correctly.""" + tm_res = calcualte_pair_cluster_confusion_matrix(preds, target) + sklearn_res = sklearn_pair_confusion_matrix(preds, target) + assert np.allclose(tm_res, sklearn_res, atol=self.atol) From c272ee3d207951a21187687823bf985d5b23dafb Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 19:55:56 +0200 Subject: [PATCH 36/44] fixes to doctest + docstrings --- .../clustering/mutual_info_score.py | 4 +-- src/torchmetrics/clustering/rand_score.py | 20 +++++++------- .../clustering/mutual_info_score.py | 4 +-- .../functional/clustering/rand_score.py | 27 ++++++++++--------- tests/unittests/utilities/test_plot.py | 3 +++ 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index 86118daf41c..e943c2aec27 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -41,8 +41,8 @@ class MutualInfoScore(Metric): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels As output of ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/clustering/rand_score.py 
b/src/torchmetrics/clustering/rand_score.py index fd87706864f..49341b00464 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -38,12 +38,12 @@ class RandScore(Metric): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` - - ``target`` (:class:`~torch.Tensor`): either single integer tensor with shape ``(N,)`` + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels As output of ``forward`` and ``compute`` the metric returns the following output: - - ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score + - ``rand_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score Args: kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
@@ -53,9 +53,9 @@ class RandScore(Metric): >>> from torchmetrics.clustering import RandScore >>> preds = torch.tensor([2, 1, 0, 1, 0]) >>> target = torch.tensor([0, 2, 1, 1, 0]) - >>> mi_score = RandScore() - >>> mi_score(preds, target) - tensor(0.5004) + >>> metric = RandScore() + >>> metric(preds, target) + tensor(0.6000) """ @@ -102,8 +102,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore() + >>> from torchmetrics.clustering import RandScore + >>> metric = RandScore() >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) >>> fig_, ax_ = metric.plot(metric.compute()) @@ -112,8 +112,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.clustering import MutualInfoScore - >>> metric = MutualInfoScore() + >>> from torchmetrics.clustering import RandScore + >>> metric = RandScore() >>> for _ in range(10): ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) >>> fig_, ax_ = metric.plot(metric.compute()) diff --git a/src/torchmetrics/functional/clustering/mutual_info_score.py b/src/torchmetrics/functional/clustering/mutual_info_score.py index f7c7cbfa587..a729726436e 100644 --- a/src/torchmetrics/functional/clustering/mutual_info_score.py +++ b/src/torchmetrics/functional/clustering/mutual_info_score.py @@ -64,8 +64,8 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor: """Compute mutual information between two clusterings. 
Args: - preds: predicted classes - target: ground truth classes + preds: predicted cluster labels + target: ground truth cluster labels Example: >>> from torchmetrics.functional.clustering import mutual_info_score diff --git a/src/torchmetrics/functional/clustering/rand_score.py b/src/torchmetrics/functional/clustering/rand_score.py index 6c5b96cdf3b..98e7f0258ea 100644 --- a/src/torchmetrics/functional/clustering/rand_score.py +++ b/src/torchmetrics/functional/clustering/rand_score.py @@ -14,7 +14,11 @@ import torch from torch import Tensor -from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels +from torchmetrics.functional.clustering.utils import ( + calcualte_pair_cluster_confusion_matrix, + calculate_contingency_matrix, + check_cluster_labels, +) def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor: @@ -42,16 +46,7 @@ def _rand_score_compute(contingency: Tensor) -> Tensor: rand_score: rand score """ - n_samples = contingency.sum() - n_c = contingency.sum(dim=1) - n_k = contingency.sum(dim=0) - sum_squared = (contingency**2).sum() - - pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device) - pair_matrix[1, 1] = sum_squared - n_samples - pair_matrix[0, 1] = (contingency * n_k).sum() - sum_squared - pair_matrix[1, 0] = (contingency.T * n_c).sum() - sum_squared - pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared + pair_matrix = calcualte_pair_cluster_confusion_matrix(contingency=contingency) numerator = pair_matrix.diagonal().sum() denominator = pair_matrix.sum() @@ -59,7 +54,7 @@ def _rand_score_compute(contingency: Tensor) -> Tensor: # Special limit cases: no clustering since the data is not split; # or trivial clustering where each document is assigned a unique # cluster. These are perfect matches hence return 1.0. 
- return 1.0 + return torch.ones_like(numerator, dtype=torch.float32) return numerator / denominator @@ -74,6 +69,14 @@ def rand_score(preds: Tensor, target: Tensor) -> Tensor: Returns: rand_score: rand score + Example: + >>> from torchmetrics.functional.clustering import rand_score + >>> import torch + >>> rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0])) + tensor(1.) + >>> rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) + tensor(0.8333) + """ contingency = _rand_score_update(preds, target) return _rand_score_compute(contingency) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index b47b06da7f8..f5c7d8d4562 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -91,6 +91,7 @@ MultilabelROC, MultilabelSpecificity, ) +from torchmetrics.clustering import MutualInfoScore, RandScore from torchmetrics.detection import PanopticQuality from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio @@ -614,6 +615,8 @@ id="squad", ), pytest.param(TranslationEditRate, _text_input_3, _text_input_4, id="translation edit rate"), + pytest.param(MutualInfoScore, _nominal_input, _nominal_input, id="mutual info score"), + pytest.param(RandScore, _nominal_input, _nominal_input, id="rand score"), ], ) @pytest.mark.parametrize("num_vals", [1, 3]) From e733451e9a9323bf7c9f229c6f9d0e8a96381a0b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 25 Aug 2023 19:58:52 +0200 Subject: [PATCH 37/44] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b23ce58d355..25743bfb0dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `MutualInformationScore` metric to cluster package 
([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008) +- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)) + + ### Changed - From 741f5df758380c12d7bcb07766dc0b553c662ef1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 25 Aug 2023 20:22:35 +0200 Subject: [PATCH 38/44] fix --- src/torchmetrics/functional/clustering/rand_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/clustering/rand_score.py b/src/torchmetrics/functional/clustering/rand_score.py index 98e7f0258ea..2e6848f46d5 100644 --- a/src/torchmetrics/functional/clustering/rand_score.py +++ b/src/torchmetrics/functional/clustering/rand_score.py @@ -67,7 +67,7 @@ def rand_score(preds: Tensor, target: Tensor) -> Tensor: target: ground truth cluster labels Returns: - rand_score: rand score + scalar tensor with the rand score Example: >>> from torchmetrics.functional.clustering import rand_score From b75c0b3332703b4e8e4064637a575059d5c33018 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 26 Aug 2023 10:09:33 +0200 Subject: [PATCH 39/44] Update tests/unittests/clustering/test_rand_score.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- tests/unittests/clustering/test_rand_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index a5a6df11267..9ff11b1d6c5 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -87,7 +87,7 @@ def test_rand_score_functional_raises_invalid_task(): @pytest.mark.parametrize( ("preds", "target"), [ - (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs1.preds, _single_target_inputs1.target) ], ) def test_rand_score_functional_is_symmetric(preds, target): From 30dfbde4173f69a5b7d9da463068f429fa4fa5d0 Mon Sep 17 
00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 26 Aug 2023 08:10:07 +0000 Subject: [PATCH 40/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/clustering/test_rand_score.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 9ff11b1d6c5..d00fd421d34 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -86,9 +86,7 @@ def test_rand_score_functional_raises_invalid_task(): @pytest.mark.parametrize( ("preds", "target"), - [ - (_single_target_inputs1.preds, _single_target_inputs1.target) - ], + [(_single_target_inputs1.preds, _single_target_inputs1.target)], ) def test_rand_score_functional_is_symmetric(preds, target): """Check that the metric funtional is symmetric.""" From f4386c5eedc668c986d77f5dd30c6e1ba1ab6dd1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 26 Aug 2023 19:21:32 +0200 Subject: [PATCH 41/44] Update src/torchmetrics/clustering/rand_score.py Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- src/torchmetrics/clustering/rand_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 49341b00464..ba522b39062 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -26,7 +26,7 @@ class RandScore(Metric): - r"""Compute `Rand Score`_ (alternative know as Rand Index). + r"""Compute `Rand Score`_ (alternatively known as Rand Index). .. 
math:: RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs} From a0581fa339873abee682bd3003443165c737cfcd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 26 Aug 2023 19:21:41 +0200 Subject: [PATCH 42/44] Update src/torchmetrics/clustering/rand_score.py Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- src/torchmetrics/clustering/rand_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index ba522b39062..a7fa5bb83f8 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -79,7 +79,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.target.append(target) def compute(self) -> Tensor: - """Compute mutual information over state.""" + """Compute rand score over state.""" return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: From 10c579b55c7509527d0e7b0d0ea1c4d5fc008d54 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 26 Aug 2023 19:24:50 +0200 Subject: [PATCH 43/44] fix mypy --- src/torchmetrics/functional/clustering/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/torchmetrics/functional/clustering/utils.py b/src/torchmetrics/functional/clustering/utils.py index 048a3a0d106..c50a2b03f5b 100644 --- a/src/torchmetrics/functional/clustering/utils.py +++ b/src/torchmetrics/functional/clustering/utils.py @@ -155,6 +155,9 @@ def calcualte_pair_cluster_confusion_matrix( if preds is not None and target is not None: contingency = calculate_contingency_matrix(preds, target) + if contingency is None: + raise ValueError("Must provide `contingency` if `preds` and `target` are not provided.") + n_samples = contingency.sum() n_c = contingency.sum(dim=1) n_k = contingency.sum(dim=0) From 
90bf3b7b90ef64ae3e3c8603b6bf063166c96240 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 28 Aug 2023 11:11:07 +0200 Subject: [PATCH 44/44] skip on too much memory --- tests/unittests/__init__.py | 12 ++++++++++- tests/unittests/conftest.py | 20 +++++++++++++++++++ .../image/test_perceptual_path_length.py | 3 +++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py index bed0bda5147..e77f74161bb 100644 --- a/tests/unittests/__init__.py +++ b/tests/unittests/__init__.py @@ -3,7 +3,16 @@ import numpy import torch -from unittests.conftest import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, NUM_PROCESSES, THRESHOLD, setup_ddp +from unittests.conftest import ( + BATCH_SIZE, + EXTRA_DIM, + NUM_BATCHES, + NUM_CLASSES, + NUM_PROCESSES, + THRESHOLD, + setup_ddp, + skip_on_running_out_of_memory, +) # adding compatibility for numpy >= 1.24 for tp_name, tp_ins in [("object", object), ("bool", bool), ("int", int), ("float", float)]: @@ -25,4 +34,5 @@ "NUM_PROCESSES", "THRESHOLD", "setup_ddp", + "skip_on_running_out_of_memory", ] diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 61c1b6fd864..90c53a387f7 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -14,6 +14,8 @@ import contextlib import os import sys +from functools import wraps +from typing import Any, Callable, Optional import pytest import torch @@ -69,3 +71,21 @@ def pytest_sessionfinish(): """ pytest.pool.close() pytest.pool.join() + + +def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."): + """Handle tests that sometimes run out of memory by simply skipping them.""" + + def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]: + @wraps(function) + def run_test(*args: Any, **kwargs: Any) -> Optional[Any]: + try: + return function(*args, **kwargs) + except RuntimeError as ex: + if "DefaultCPUAllocator: not enough memory:" 
not in str(ex): + raise ex + pytest.skip(reason) + + return run_test + + return test_decorator diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py index 8535f74a524..0f76ce51372 100644 --- a/tests/unittests/image/test_perceptual_path_length.py +++ b/tests/unittests/image/test_perceptual_path_length.py @@ -24,6 +24,7 @@ from torchmetrics.image.perceptual_path_length import PerceptualPathLength from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE +from unittests import skip_on_running_out_of_memory from unittests.helpers import seed_all seed_all(42) @@ -42,6 +43,7 @@ def test_interpolation_methods(interpolation_method): @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity") +@skip_on_running_out_of_memory() def test_sim_net(): """Check that the similiarity network is the same as the one used in torch_fidelity.""" compare = SampleSimilarityLPIPS("sample_similarity", resize=64) @@ -113,6 +115,7 @@ def sample(self, num_samples): ({"upper_discard": 2}, "Argument `upper_discard` must be a float between 0 and 1 or `None`, but got 2"), ], ) +@skip_on_running_out_of_memory() def test_raises_error_on_wrong_arguments(argument, match): """Test that appropriate errors are raised on wrong arguments.""" with pytest.raises(ValueError, match=match):