forked from Lightning-AI/torchmetrics
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New metric: Adjusted Rand Score (Lightning-AI#2032)
* initial implementation * add init files * add tests * docs * fix doc tests * changelog * fix * change image * fix * use new inputs * Update src/torchmetrics/clustering/adjusted_rand_score.py --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent
bb3cb38
commit 990ab04
Showing
11 changed files
with
318 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
.. customcarditem:: | ||
:header: Adjusted Rand Score | ||
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg | ||
:tags: Clustering | ||
|
||
.. include:: ../links.rst | ||
|
||
################### | ||
Adjusted Rand Score | ||
################### | ||
|
||
Module Interface | ||
________________ | ||
|
||
.. autoclass:: torchmetrics.clustering.AdjustedRandScore | ||
:exclude-members: update, compute | ||
|
||
Functional Interface | ||
____________________ | ||
|
||
.. autofunction:: torchmetrics.functional.clustering.adjusted_rand_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright The Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, List, Optional, Sequence, Union | ||
|
||
from torch import Tensor | ||
|
||
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score | ||
from torchmetrics.metric import Metric | ||
from torchmetrics.utilities.data import dim_zero_cat | ||
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE | ||
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE | ||
|
||
if not _MATPLOTLIB_AVAILABLE: | ||
__doctest_skip__ = ["AdjustedRandScore.plot"] | ||
|
||
|
||
class AdjustedRandScore(Metric): | ||
r"""Compute `Adjusted Rand Score`_ (also known as Adjusted Rand Index). | ||
.. math:: | ||
ARS(U, V) = (\text{RS} - \text{Expected RS}) / (\text{Max RS} - \text{Expected RS}) | ||
The adjusted rand score :math:`\text{ARS}` is in essence the :math:`\text{RS}` (rand score) adjusted for chance. | ||
The score ensures that completly randomly cluster labels have a score close to zero and only a perfect match will | ||
have a score of 1 (up to a permutation of the labels). The adjusted rand score is symmetric, therefore swapping | ||
:math:`U` and :math:`V` yields the same adjusted rand score. | ||
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not | ||
be available in practice since clustering is generally used for unsupervised learning. | ||
As input to ``forward`` and ``update`` the metric accepts the following input: | ||
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels | ||
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels | ||
As output of ``forward`` and ``compute`` the metric returns the following output: | ||
- ``adj_rand_score`` (:class:`~torch.Tensor`): Scalar tensor with the adjusted rand score | ||
Args: | ||
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. | ||
Example: | ||
>>> import torch | ||
>>> from torchmetrics.clustering import AdjustedRandScore | ||
>>> metric = AdjustedRandScore() | ||
>>> metric(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1])) | ||
tensor(1.) | ||
>>> metric(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 1, 0, 1])) | ||
tensor(-0.5000) | ||
""" | ||
|
||
is_differentiable = True | ||
higher_is_better = None | ||
full_state_update: bool = True | ||
plot_lower_bound: float = -0.5 | ||
plot_upper_bound: float = 1.0 | ||
preds: List[Tensor] | ||
target: List[Tensor] | ||
|
||
def __init__(self, **kwargs: Any) -> None: | ||
super().__init__(**kwargs) | ||
|
||
self.add_state("preds", default=[], dist_reduce_fx="cat") | ||
self.add_state("target", default=[], dist_reduce_fx="cat") | ||
|
||
def update(self, preds: Tensor, target: Tensor) -> None: | ||
"""Update state with predictions and targets.""" | ||
self.preds.append(preds) | ||
self.target.append(target) | ||
|
||
def compute(self) -> Tensor: | ||
"""Compute mutual information over state.""" | ||
return adjusted_rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) | ||
|
||
def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: | ||
"""Plot a single or multiple values from the metric. | ||
Args: | ||
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. | ||
If no value is provided, will automatically call `metric.compute` and plot that result. | ||
ax: An matplotlib axis object. If provided will add plot to that axis | ||
Returns: | ||
Figure and Axes object | ||
Raises: | ||
ModuleNotFoundError: | ||
If `matplotlib` is not installed | ||
.. plot:: | ||
:scale: 75 | ||
>>> # Example plotting a single value | ||
>>> import torch | ||
>>> from torchmetrics.clustering import AdjustedRandScore | ||
>>> metric = AdjustedRandScore() | ||
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) | ||
>>> fig_, ax_ = metric.plot(metric.compute()) | ||
.. plot:: | ||
:scale: 75 | ||
>>> # Example plotting multiple values | ||
>>> import torch | ||
>>> from torchmetrics.clustering import AdjustedRandScore | ||
>>> metric = AdjustedRandScore() | ||
>>> for _ in range(10): | ||
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) | ||
>>> fig_, ax_ = metric.plot(metric.compute()) | ||
""" | ||
return self._plot(val, ax) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
src/torchmetrics/functional/clustering/adjusted_rand_score.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright The Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.functional.clustering.utils import ( | ||
calcualte_pair_cluster_confusion_matrix, | ||
calculate_contingency_matrix, | ||
check_cluster_labels, | ||
) | ||
|
||
|
||
def _adjusted_rand_score_update(preds: Tensor, target: Tensor) -> Tensor: | ||
"""Update and return variables required to compute the rand score. | ||
Args: | ||
preds: predicted cluster labels | ||
target: ground truth cluster labels | ||
Returns: | ||
contingency: contingency matrix | ||
""" | ||
check_cluster_labels(preds, target) | ||
return calculate_contingency_matrix(preds, target) | ||
|
||
|
||
def _adjusted_rand_score_compute(contingency: Tensor) -> Tensor: | ||
"""Compute the rand score based on the contingency matrix. | ||
Args: | ||
contingency: contingency matrix | ||
Returns: | ||
rand_score: rand score | ||
""" | ||
(tn, fp), (fn, tp) = calcualte_pair_cluster_confusion_matrix(contingency=contingency) | ||
if fn == 0 and fp == 0: | ||
return torch.ones_like(tn, dtype=torch.float32) | ||
return 2.0 * (tp * tn - fn * fp) / ((tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)) | ||
|
||
|
||
def adjusted_rand_score(preds: Tensor, target: Tensor) -> Tensor: | ||
"""Compute the Adjusted Rand score between two clusterings. | ||
Args: | ||
preds: predicted cluster labels | ||
target: ground truth cluster labels | ||
Returns: | ||
Scalar tensor with adjusted rand score | ||
Example: | ||
>>> from torchmetrics.functional.clustering import adjusted_rand_score | ||
>>> import torch | ||
>>> adjusted_rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1])) | ||
tensor(1.) | ||
>>> adjusted_rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) | ||
tensor(0.5714) | ||
""" | ||
contingency = _adjusted_rand_score_update(preds, target) | ||
return _adjusted_rand_score_compute(contingency) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright The Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pytest | ||
import torch | ||
from sklearn.metrics import adjusted_rand_score as sklearn_adjusted_rand_score | ||
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore | ||
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score | ||
|
||
from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 | ||
from unittests.helpers.testers import MetricTester | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"preds, target", | ||
[ | ||
(_single_target_extrinsic1.preds, _single_target_extrinsic1.target), | ||
(_single_target_extrinsic2.preds, _single_target_extrinsic2.target), | ||
], | ||
) | ||
class TestAdjustedRandScore(MetricTester): | ||
"""Test class for `AdjustedRandScore` metric.""" | ||
|
||
atol = 1e-5 | ||
|
||
@pytest.mark.parametrize("ddp", [True, False]) | ||
def test_adjusted_rand_score(self, preds, target, ddp): | ||
"""Test class implementation of metric.""" | ||
self.run_class_metric_test( | ||
ddp=ddp, | ||
preds=preds, | ||
target=target, | ||
metric_class=AdjustedRandScore, | ||
reference_metric=sklearn_adjusted_rand_score, | ||
) | ||
|
||
def test_rand_score_functional(self, preds, target): | ||
"""Test functional implementation of metric.""" | ||
self.run_functional_metric_test( | ||
preds=preds, | ||
target=target, | ||
metric_functional=adjusted_rand_score, | ||
reference_metric=sklearn_adjusted_rand_score, | ||
) | ||
|
||
|
||
def test_rand_score_functional_raises_invalid_task(): | ||
"""Check that metric rejects continuous-valued inputs.""" | ||
preds, target = _float_inputs_extrinsic | ||
with pytest.raises(ValueError, match=r"Expected *"): | ||
adjusted_rand_score(preds, target) | ||
|
||
|
||
def test_rand_score_functional_is_symmetric( | ||
preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target | ||
): | ||
"""Check that the metric funtional is symmetric.""" | ||
for p, t in zip(preds, target): | ||
assert torch.allclose(adjusted_rand_score(p, t), adjusted_rand_score(t, p)) |
Oops, something went wrong.