Classification metrics overhaul: precision & recall (4/n) (#4842)
* Add stuff

* Change metrics documentation layout

* Add stuff

* Add stat scores

* Change testing utils

* Replace len(*.shape) with *.ndim

* More descriptive error message for input formatting

* Replace movedim with permute

* PEP 8 compliance

* WIP

* Add reduce_scores function

* Temporarily add back legacy class_reduce

* Division with float

* PEP 8 compliance

* Remove precision recall

* Replace movedim with permute

* Add back tests

* Add empty newlines

* Add precision recall back

* Add empty line

* Fix permute

* Fix some issues with old versions of PyTorch

* Style changes in error messages

* More error message style improvements

* Fix typo in docs

* Add more descriptive variable names in utils

* Change internal var names

* Revert unwanted changes

* Revert unwanted changes pt 2

* Update metrics interface

* Add top_k parameter

* Add back reduce function

* Add stuff

* PEP8

* Add deprecation

* PEP8

* Deprecate param

* PEP8

* Fix and simplify testing for older PT versions

* Update Changelog

* Remove redundant import

* Add tests to increase coverage

* Remove zero_division

* fix zero_division

* Add zero_div + edge case tests

* Reorder cls metric args

* Add back quotes for is_multiclass

* Add precision_recall and tests

* PEP8

* Fix docs

* Fix docs

* Update

* Change precision_recall output

* PEP8/isort

* Add method _get_final_stats

* Fix depr test

* Add comment to deprecation tests

* isort

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

* Add typing to test

* Add match str to pytest.raises

Co-authored-by: chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Jan 18, 2021
1 parent 1ff6b18 commit c8f605e
Showing 13 changed files with 1,185 additions and 296 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -54,6 +54,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))


- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
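A minimal sketch of what this enables, assuming the functional interface documented in this PR (the tensors are made-up illustration data):

```python
import torch
from pytorch_lightning.metrics.functional import precision

# Probabilistic predictions for 4 samples over 3 classes (illustration data)
preds = torch.tensor([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.5, 0.3],
])
target = torch.tensor([0, 1, 2, 0])

# Default behavior: top-1 (micro-averaged) precision
p_at_1 = precision(preds, target, num_classes=3)

# Precision@2: a prediction counts as a hit if the target class is
# among the 2 highest-scoring classes for that sample
p_at_2 = precision(preds, target, num_classes=3, top_k=2)
```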



### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
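For context, a hedged sketch of the updated `stat_scores` behavior (made-up tensors; the exact keyword names follow the new `StatScores` metric and should be treated as assumptions):

```python
import torch
from pytorch_lightning.metrics.functional import stat_scores

preds = torch.tensor([1, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])

# Stat scores are now computed over all classes at once:
# one [tp, fp, tn, fn, support] row per class
per_class = stat_scores(preds, target, reduce='macro', num_classes=3)
```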
10 changes: 5 additions & 5 deletions docs/source/metrics.rst
@@ -382,8 +382,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])


-Using the ``is_multiclass`` parameter
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Using the is_multiclass parameter
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
@@ -602,14 +602,14 @@ roc [func]
precision [func]
~~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.precision
+.. autofunction:: pytorch_lightning.metrics.functional.precision
:noindex:


precision_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall
+.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
:noindex:


@@ -623,7 +623,7 @@ precision_recall_curve [func]
recall [func]
~~~~~~~~~~~~~

-.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
+.. autofunction:: pytorch_lightning.metrics.functional.recall
:noindex:

select_topk [func]
71 changes: 69 additions & 2 deletions pytorch_lightning/metrics/classification/helpers.py
@@ -13,6 +13,7 @@
# limitations under the License.
from typing import Tuple, Optional

import numpy as np
import torch

from pytorch_lightning.metrics.utils import to_onehot, select_topk
@@ -249,7 +250,7 @@ def _check_classification_inputs(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
-        :ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>`
+        :ref:`documentation section <metrics:Using the is_multiclass parameter>`
for a more detailed explanation and examples.
@@ -375,7 +376,7 @@ def _input_format_classification(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
-        :ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>`
+        :ref:`documentation section <metrics:Using the is_multiclass parameter>`
for a more detailed explanation and examples.
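A hedged illustration of this special case (made-up tensors; assuming the `precision` functional from this PR accepts `is_multiclass` as sketched here):

```python
import torch
from pytorch_lightning.metrics.functional import precision

# Integer predictions with values in {0, 1} are treated as binary by default
preds = torch.tensor([1, 0, 1, 1])
target = torch.tensor([1, 0, 0, 1])

# Binary interpretation: a single score for the positive class
binary_score = precision(preds, target)

# Force a 2-class multi-class interpretation instead, yielding one
# score per class (is_multiclass overrides the inferred input type)
per_class = precision(preds, target, average='none', num_classes=2, is_multiclass=True)
```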
@@ -437,3 +438,69 @@ def _input_format_classification(
    preds, target = preds.squeeze(-1), target.squeeze(-1)

    return preds.int(), target.int(), case


def _reduce_stat_scores(
    numerator: torch.Tensor,
    denominator: torch.Tensor,
    weights: Optional[torch.Tensor],
    average: str,
    mdmc_average: Optional[str],
    zero_division: int = 0,
) -> torch.Tensor:
    """
    Reduces scores of type ``numerator/denominator`` or
    ``weights * (numerator/denominator)``, if ``average='weighted'``.

    Args:
        numerator: A tensor with numerator numbers.
        denominator: A tensor with denominator numbers. If a denominator is
            negative, the class will be ignored (if averaging), or its score
            will be returned as ``nan`` (if ``average=None``).
            If the denominator is zero, then ``zero_division`` score will be
            used for those elements.
        weights:
            A tensor of weights to be used if ``average='weighted'``.
        average:
            The method to average the scores. Should be one of ``'micro'``, ``'macro'``,
            ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior
            corresponds to `sklearn averaging methods <https://scikit-learn.org/stable/modules/\
model_evaluation.html#multiclass-and-multilabel-classification>`__.
        mdmc_average:
            The method to average the scores if inputs were multi-dimensional multi-class (MDMC).
            Should be either ``'global'`` or ``'samplewise'``. If inputs were not
            multi-dimensional multi-class, it should be ``None`` (default).
        zero_division:
            The value to use for the score if denominator equals zero.
    """
    numerator, denominator = numerator.float(), denominator.float()
    zero_div_mask = denominator == 0
    ignore_mask = denominator < 0

    if weights is None:
        weights = torch.ones_like(denominator)
    else:
        weights = weights.float()

    numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator)
    denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
    weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)

    if average not in ["micro", "none", None]:
        weights = weights / weights.sum(dim=-1, keepdim=True)

    scores = weights * (numerator / denominator)

    # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
    scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)

    if mdmc_average == "samplewise":
        scores = scores.mean(dim=0)
        ignore_mask = ignore_mask.sum(dim=0).bool()

    if average in ["none", None]:
        scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
    else:
        scores = scores.sum()

    return scores
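A quick illustration of the reduction semantics above (made-up numbers; `_reduce_stat_scores` is an internal helper, so this is only a sketch of its behavior):

```python
import torch

# Per-class precision inputs for 3 classes: numerator = tp, denominator = tp + fp.
# Class 1 has a zero denominator (-> zero_division), class 2 a negative one (-> ignored).
tp = torch.tensor([3.0, 0.0, 5.0])
tp_fp = torch.tensor([4.0, 0.0, -1.0])

# Macro average: the ignored class drops out, class 1 scores zero_division
macro = _reduce_stat_scores(tp, tp_fp, weights=None, average='macro', mdmc_average=None)
# -> tensor(0.3750), i.e. mean(3/4, 0) over the two non-ignored classes

# Per-class scores: the ignored class is reported as nan
per_class = _reduce_stat_scores(tp, tp_fp, weights=None, average='none', mdmc_average=None)
# -> tensor([0.7500, 0.0000, nan])
```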