Skip to content

Commit 0281812

Browse files
SkafteNickiBordastancldpre-commit-ci[bot]
authored andcommitted
New metric: Rand Score (Lightning-AI#2025)
Co-authored-by: Nicki Skafte Detlefsen <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Daniel Stancl <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka <[email protected]>
1 parent 5e4644a commit 0281812

File tree

16 files changed

+459
-7
lines changed

16 files changed

+459
-7
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008)
1515

1616

17+
- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)
18+
19+
1720
### Changed
1821

1922
-

docs/source/clustering/rand_score.rst

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
.. customcarditem::
2+
:header: Rand Score
3+
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
4+
:tags: Clustering
5+
6+
.. include:: ../links.rst
7+
8+
##########
9+
Rand Score
10+
##########
11+
12+
Module Interface
13+
________________
14+
15+
.. autoclass:: torchmetrics.clustering.RandScore
16+
:exclude-members: update, compute
17+
18+
Functional Interface
19+
____________________
20+
21+
.. autofunction:: torchmetrics.functional.clustering.rand_score

docs/source/links.rst

+1
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,4 @@
152152
.. _GIOU: https://arxiv.org/abs/1902.09630
153153
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
154154
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
155+
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075

src/torchmetrics/clustering/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
15+
from torchmetrics.clustering.rand_score import RandScore
1516

1617
__all__ = [
1718
"MutualInfoScore",
19+
"RandScore",
1820
]

src/torchmetrics/clustering/mutual_info_score.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class MutualInfoScore(Metric):
4141
4242
As input to ``forward`` and ``update`` the metric accepts the following input:
4343
44-
- ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
45-
- ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
44+
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
45+
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels
4646
4747
As output of ``forward`` and ``compute`` the metric returns the following output:
4848
+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, List, Optional, Sequence, Union
15+
16+
from torch import Tensor
17+
18+
from torchmetrics.functional.clustering.rand_score import rand_score
19+
from torchmetrics.metric import Metric
20+
from torchmetrics.utilities.data import dim_zero_cat
21+
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
22+
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
23+
24+
if not _MATPLOTLIB_AVAILABLE:
25+
__doctest_skip__ = ["RandScore.plot"]
26+
27+
28+
class RandScore(Metric):
29+
r"""Compute `Rand Score`_ (alternatively known as Rand Index).
30+
31+
.. math::
32+
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
33+
34+
The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
35+
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
36+
37+
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
38+
39+
As input to ``forward`` and ``update`` the metric accepts the following input:
40+
41+
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
42+
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels
43+
44+
As output of ``forward`` and ``compute`` the metric returns the following output:
45+
46+
- ``rand_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score
47+
48+
Args:
49+
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
50+
51+
Example:
52+
>>> import torch
53+
>>> from torchmetrics.clustering import RandScore
54+
>>> preds = torch.tensor([2, 1, 0, 1, 0])
55+
>>> target = torch.tensor([0, 2, 1, 1, 0])
56+
>>> metric = RandScore()
57+
>>> metric(preds, target)
58+
tensor(0.6000)
59+
60+
"""
61+
62+
is_differentiable = True
63+
higher_is_better = None
64+
full_state_update: bool = True
65+
plot_lower_bound: float = 0.0
66+
preds: List[Tensor]
67+
target: List[Tensor]
68+
contingency: Tensor
69+
70+
def __init__(self, **kwargs: Any) -> None:
71+
super().__init__(**kwargs)
72+
73+
self.add_state("preds", default=[], dist_reduce_fx="cat")
74+
self.add_state("target", default=[], dist_reduce_fx="cat")
75+
76+
def update(self, preds: Tensor, target: Tensor) -> None:
77+
"""Update state with predictions and targets."""
78+
self.preds.append(preds)
79+
self.target.append(target)
80+
81+
def compute(self) -> Tensor:
82+
"""Compute rand score over state."""
83+
return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))
84+
85+
def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
86+
"""Plot a single or multiple values from the metric.
87+
88+
Args:
89+
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
90+
If no value is provided, will automatically call `metric.compute` and plot that result.
91+
ax: An matplotlib axis object. If provided will add plot to that axis
92+
93+
Returns:
94+
Figure and Axes object
95+
96+
Raises:
97+
ModuleNotFoundError:
98+
If `matplotlib` is not installed
99+
100+
.. plot::
101+
:scale: 75
102+
103+
>>> # Example plotting a single value
104+
>>> import torch
105+
>>> from torchmetrics.clustering import RandScore
106+
>>> metric = RandScore()
107+
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
108+
>>> fig_, ax_ = metric.plot(metric.compute())
109+
110+
.. plot::
111+
:scale: 75
112+
113+
>>> # Example plotting multiple values
114+
>>> import torch
115+
>>> from torchmetrics.clustering import RandScore
116+
>>> metric = RandScore()
117+
>>> for _ in range(10):
118+
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
119+
>>> fig_, ax_ = metric.plot(metric.compute())
120+
121+
"""
122+
return self._plot(val, ax)

src/torchmetrics/functional/clustering/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
15+
from torchmetrics.functional.clustering.rand_score import rand_score
1516

16-
__all__ = ["mutual_info_score"]
17+
__all__ = ["mutual_info_score", "rand_score"]

src/torchmetrics/functional/clustering/mutual_info_score.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor:
6464
"""Compute mutual information between two clusterings.
6565
6666
Args:
67-
preds: predicted classes
68-
target: ground truth classes
67+
preds: predicted cluster labels
68+
target: ground truth cluster labels
6969
7070
Example:
7171
>>> from torchmetrics.functional.clustering import mutual_info_score
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import torch
15+
from torch import Tensor
16+
17+
from torchmetrics.functional.clustering.utils import (
18+
calcualte_pair_cluster_confusion_matrix,
19+
calculate_contingency_matrix,
20+
check_cluster_labels,
21+
)
22+
23+
24+
def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor:
25+
"""Update and return variables required to compute the rand score.
26+
27+
Args:
28+
preds: predicted cluster labels
29+
target: ground truth cluster labels
30+
31+
Returns:
32+
contingency: contingency matrix
33+
34+
"""
35+
check_cluster_labels(preds, target)
36+
return calculate_contingency_matrix(preds, target)
37+
38+
39+
def _rand_score_compute(contingency: Tensor) -> Tensor:
40+
"""Compute the rand score based on the contingency matrix.
41+
42+
Args:
43+
contingency: contingency matrix
44+
45+
Returns:
46+
rand_score: rand score
47+
48+
"""
49+
pair_matrix = calcualte_pair_cluster_confusion_matrix(contingency=contingency)
50+
51+
numerator = pair_matrix.diagonal().sum()
52+
denominator = pair_matrix.sum()
53+
if numerator == denominator or denominator == 0:
54+
# Special limit cases: no clustering since the data is not split;
55+
# or trivial clustering where each document is assigned a unique
56+
# cluster. These are perfect matches hence return 1.0.
57+
return torch.ones_like(numerator, dtype=torch.float32)
58+
59+
return numerator / denominator
60+
61+
62+
def rand_score(preds: Tensor, target: Tensor) -> Tensor:
63+
"""Compute the Rand score between two clusterings.
64+
65+
Args:
66+
preds: predicted cluster labels
67+
target: ground truth cluster labels
68+
69+
Returns:
70+
scalar tensor with the rand score
71+
72+
Example:
73+
>>> from torchmetrics.functional.clustering import rand_score
74+
>>> import torch
75+
>>> rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0]))
76+
tensor(1.)
77+
>>> rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1]))
78+
tensor(0.8333)
79+
80+
"""
81+
contingency = _rand_score_update(preds, target)
82+
return _rand_score_compute(contingency)

src/torchmetrics/functional/clustering/utils.py

+70
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,73 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
9999
f"Expected real, discrete values but received {preds.dtype} for"
100100
f"predictions and {target.dtype} for target labels instead."
101101
)
102+
103+
104+
def calcualte_pair_cluster_confusion_matrix(
105+
preds: Optional[Tensor] = None,
106+
target: Optional[Tensor] = None,
107+
contingency: Optional[Tensor] = None,
108+
) -> Tensor:
109+
"""Calculates the pair cluster confusion matrix.
110+
111+
Can either be calculated from predicted cluster labels and target cluster labels or from a pre-computed
112+
contingency matrix. The pair cluster confusion matrix is a 2x2 matrix where that defines the similarity between
113+
two clustering by considering all pairs of samples and counting pairs that are assigned into same or different
114+
clusters in the predicted and target clusterings.
115+
116+
Note that the matrix is not symmetric.
117+
118+
Inspired by:
119+
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cluster.pair_confusion_matrix.html
120+
121+
Args:
122+
preds: predicted cluster labels
123+
target: ground truth cluster labels
124+
contingency: contingency matrix
125+
126+
Returns:
127+
A 2x2 tensor containing the pair cluster confusion matrix.
128+
129+
Raises:
130+
ValueError:
131+
If neither `preds` and `target` nor `contingency` are provided.
132+
ValueError:
133+
If both `preds` and `target` and `contingency` are provided.
134+
135+
Example:
136+
>>> import torch
137+
>>> from torchmetrics.functional.clustering.utils import calcualte_pair_cluster_confusion_matrix
138+
>>> preds = torch.tensor([0, 0, 1, 1])
139+
>>> target = torch.tensor([1, 1, 0, 0])
140+
>>> calcualte_pair_cluster_confusion_matrix(preds, target)
141+
tensor([[8, 0],
142+
[0, 4]])
143+
>>> preds = torch.tensor([0, 0, 1, 2])
144+
>>> target = torch.tensor([0, 0, 1, 1])
145+
>>> calcualte_pair_cluster_confusion_matrix(preds, target)
146+
tensor([[8, 2],
147+
[0, 2]])
148+
149+
"""
150+
if preds is None and target is None and contingency is None:
151+
raise ValueError("Must provide either `preds` and `target` or `contingency`.")
152+
if preds is not None and target is not None and contingency is not None:
153+
raise ValueError("Must provide either `preds` and `target` or `contingency`, not both.")
154+
155+
if preds is not None and target is not None:
156+
contingency = calculate_contingency_matrix(preds, target)
157+
158+
if contingency is None:
159+
raise ValueError("Must provide `contingency` if `preds` and `target` are not provided.")
160+
161+
n_samples = contingency.sum()
162+
n_c = contingency.sum(dim=1)
163+
n_k = contingency.sum(dim=0)
164+
sum_squared = (contingency**2).sum()
165+
166+
pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device)
167+
pair_matrix[1, 1] = sum_squared - n_samples
168+
pair_matrix[1, 0] = (contingency * n_k).sum() - sum_squared
169+
pair_matrix[0, 1] = (contingency.T * n_c).sum() - sum_squared
170+
pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared
171+
return pair_matrix

tests/unittests/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
import numpy
44
import torch
55

6-
from unittests.conftest import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, NUM_PROCESSES, THRESHOLD, setup_ddp
6+
from unittests.conftest import (
7+
BATCH_SIZE,
8+
EXTRA_DIM,
9+
NUM_BATCHES,
10+
NUM_CLASSES,
11+
NUM_PROCESSES,
12+
THRESHOLD,
13+
setup_ddp,
14+
skip_on_running_out_of_memory,
15+
)
716

817
# adding compatibility for numpy >= 1.24
918
for tp_name, tp_ins in [("object", object), ("bool", bool), ("int", int), ("float", float)]:
@@ -25,4 +34,5 @@
2534
"NUM_PROCESSES",
2635
"THRESHOLD",
2736
"setup_ddp",
37+
"skip_on_running_out_of_memory",
2838
]

0 commit comments

Comments
 (0)