[New metric] Matthews corrcoef (#98)
* init implementation

* files

* move file

* init files

* docs

* fix tests

* pep8

* Update CHANGELOG.md

* isort
SkafteNicki authored Mar 23, 2021
1 parent 84cd7ff commit 536357c
Showing 10 changed files with 325 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RetrievalMAP` metric for Information Retrieval ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))


- Added `MatthewsCorrcoef` metric ([#98](https://github.com/PyTorchLightning/metrics/pull/98))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
5 changes: 5 additions & 0 deletions docs/source/references/functional.rst
@@ -83,6 +83,11 @@ iou [func]
.. autofunction:: torchmetrics.functional.iou
    :noindex:

matthews_corrcoef [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.matthews_corrcoef
    :noindex:

roc [func]
~~~~~~~~~~~~~~~~~~~~~
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
@@ -168,6 +168,12 @@ IoU
.. autoclass:: torchmetrics.IoU
    :noindex:

MatthewsCorrcoef
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MatthewsCorrcoef
    :noindex:

Hamming Distance
~~~~~~~~~~~~~~~~

127 changes: 127 additions & 0 deletions tests/classification/test_matthews_corrcoef.py
@@ -0,0 +1,127 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef

from tests.classification.inputs import _input_binary, _input_binary_prob
from tests.classification.inputs import _input_multiclass as _input_mcls
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mlb
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef

torch.manual_seed(42)


def _sk_matthews_corrcoef_binary_prob(preds, target):
    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_binary(preds, target):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_multilabel_prob(preds, target):
    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_multilabel(preds, target):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_multiclass_prob(preds, target):
    sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_multiclass(preds, target):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target):
    sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


def _sk_matthews_corrcoef_multidim_multiclass(preds, target):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()

    return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


@pytest.mark.parametrize(
    "preds, target, sk_metric, num_classes",
    [(_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2),
     (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2),
     (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2),
     (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2),
     (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES),
     (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES),
     (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_matthews_corrcoef_multidim_multiclass_prob, NUM_CLASSES),
     (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES)]
)
class TestMatthewsCorrCoef(MetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=MatthewsCorrcoef,
            sk_metric=sk_metric,
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
            }
        )

    def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes):
        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=matthews_corrcoef,
            sk_metric=sk_metric,
            metric_args={
                "num_classes": num_classes,
                "threshold": THRESHOLD,
            }
        )
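For readers who want to sanity-check the new metric against scikit-learn outside the MetricTester harness, a minimal standalone comparison for the binary-probability case could look like the sketch below. The random data, the sample count and the tolerance are illustrative assumptions, not part of this PR; only the matthews_corrcoef call signature (preds, target, num_classes, threshold) comes from the code above.

import numpy as np
import torch
from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef
from torchmetrics.functional import matthews_corrcoef

torch.manual_seed(42)

# Binary probabilities are thresholded at 0.5 before comparison,
# mirroring _sk_matthews_corrcoef_binary_prob above.
preds = torch.rand(100)                # predicted probabilities
target = torch.randint(0, 2, (100,))   # ground-truth labels

tm_score = matthews_corrcoef(preds, target, num_classes=2, threshold=0.5)
sk_score = sk_matthews_corrcoef(
    y_true=target.numpy(),
    y_pred=(preds.numpy() >= 0.5).astype(np.uint8),
)
assert torch.allclose(tm_score, torch.tensor(sk_score, dtype=tm_score.dtype), atol=1e-5)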
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -31,6 +31,7 @@
    FBeta,
    HammingDistance,
    IoU,
    MatthewsCorrcoef,
    Precision,
    PrecisionRecallCurve,
    Recall,
1 change: 1 addition & 0 deletions torchmetrics/classification/__init__.py
@@ -20,6 +20,7 @@
from torchmetrics.classification.f_beta import F1, FBeta # noqa: F401
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.roc import ROC # noqa: F401
114 changes: 114 additions & 0 deletions torchmetrics/classification/matthews_corrcoef.py
@@ -0,0 +1,114 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor

from torchmetrics.functional.classification.matthews_corrcoef import (
_matthews_corrcoef_compute,
_matthews_corrcoef_update,
)
from torchmetrics.metric import Metric


class MatthewsCorrcoef(Metric):
    r"""
    Calculates `Matthews correlation coefficient
    <https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_ that measures
    the general correlation or quality of a classification. In the binary case it
    is defined as:

    .. math::
        MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}}

    where TP, TN, FP and FN are respectively the true positives, true negatives,
    false positives and false negatives. Also works in the case of multi-label or
    multi-class input.

    Note:
        This metric produces a multi-dimensional output, so it cannot be directly logged.

    Forward accepts

    - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
    - ``target`` (long tensor): ``(N, ...)``

    If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
    to convert into integer labels. This is the case for binary and multi-label probabilities.

    If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

    Args:
        num_classes: Number of classes in the dataset.
        threshold:
            Threshold value for binary or multi-label probabilities. default: 0.5
        compute_on_step:
            Forward only calls ``update()`` and returns None if this is set to False. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather

    Example:
        >>> from torchmetrics import MatthewsCorrcoef
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> preds = torch.tensor([0, 1, 0, 0])
        >>> matthews_corrcoef = MatthewsCorrcoef(num_classes=2)
        >>> matthews_corrcoef(preds, target)
        tensor(0.5774)
    """
    def __init__(
        self,
        num_classes: int,
        threshold: float = 0.5,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None,
    ):

        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.num_classes = num_classes
        self.threshold = threshold

        self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold)
        self.confmat += confmat

    def compute(self) -> Tensor:
        """
        Computes matthews correlation coefficient
        """
        return _matthews_corrcoef_compute(self.confmat)
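A minimal sketch of how the stateful module API added above is meant to be used across batches. The batch sizes and the three-class setup are arbitrary illustration; ``update``/``compute`` come from this module and ``reset`` from the Metric base class.

import torch
from torchmetrics import MatthewsCorrcoef

metric = MatthewsCorrcoef(num_classes=3)

for _ in range(4):
    preds = torch.randint(0, 3, (8,))   # predicted class labels for one batch
    target = torch.randint(0, 3, (8,))  # ground-truth labels for the same batch
    metric.update(preds, target)        # accumulates into the ``confmat`` state

score = metric.compute()                # turns the accumulated confusion matrix into MCC
metric.reset()                          # clears the state, e.g. between epochs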
1 change: 1 addition & 0 deletions torchmetrics/functional/__init__.py
@@ -21,6 +21,7 @@
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
from torchmetrics.functional.classification.roc import roc # noqa: F401
1 change: 1 addition & 0 deletions torchmetrics/functional/classification/__init__.py
@@ -21,6 +21,7 @@
from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.iou import iou # noqa: F401
from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401
from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401
from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve # noqa: F401
from torchmetrics.functional.classification.roc import roc # noqa: F401
66 changes: 66 additions & 0 deletions torchmetrics/functional/classification/matthews_corrcoef.py
@@ -0,0 +1,66 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update

_matthews_corrcoef_update = _confusion_matrix_update


def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor:
    tk = confmat.sum(dim=0).float()
    pk = confmat.sum(dim=1).float()
    c = torch.trace(confmat).float()
    s = confmat.sum().float()
    return (c * s - sum(tk * pk)) / (torch.sqrt(s ** 2 - sum(pk * pk)) * torch.sqrt(s ** 2 - sum(tk * tk)))


def matthews_corrcoef(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    threshold: float = 0.5
) -> Tensor:
    r"""
    Calculates `Matthews correlation coefficient
    <https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_ that measures
    the general correlation or quality of a classification. In the binary case it
    is defined as:

    .. math::
        MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}}

    where TP, TN, FP and FN are respectively the true positives, true negatives,
    false positives and false negatives. Also works in the case of multi-label or
    multi-class input.

    Args:
        preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or
            ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities
        target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground truth labels
        num_classes: Number of classes in the dataset.
        threshold:
            Threshold value for binary or multi-label probabilities. default: 0.5

    Example:
        >>> from torchmetrics.functional import matthews_corrcoef
        >>> target = torch.tensor([1, 1, 0, 0])
        >>> preds = torch.tensor([0, 1, 0, 0])
        >>> matthews_corrcoef(preds, target, num_classes=2)
        tensor(0.5774)
    """
    confmat = _matthews_corrcoef_update(preds, target, num_classes, threshold)
    return _matthews_corrcoef_compute(confmat)
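As a quick check that the generalized confusion-matrix formula in _matthews_corrcoef_compute reduces to the binary TP/TN/FP/FN definition from the docstring, here is the docstring example written out by hand. The confusion matrix and intermediate values below are derived manually for illustration, not produced by this PR.

import torch

# target = [1, 1, 0, 0], preds = [0, 1, 0, 0]
confmat = torch.tensor([[2., 0.],   # row = true class, column = predicted class
                        [1., 1.]])

tk = confmat.sum(dim=0)   # predictions per class: [3., 1.]
pk = confmat.sum(dim=1)   # targets per class:     [2., 2.]
c = torch.trace(confmat)  # correctly classified samples: 3.
s = confmat.sum()         # total samples: 4.

mcc = (c * s - (tk * pk).sum()) / (
    torch.sqrt(s ** 2 - (pk * pk).sum()) * torch.sqrt(s ** 2 - (tk * tk).sum())
)
print(mcc)  # tensor(0.5774), i.e. (12 - 8) / (sqrt(8) * sqrt(6))
# which matches the binary formula (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
# with TP=1, TN=2, FP=0, FN=1: 2 / sqrt(12) = 0.5774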
