Skip to content

Commit

Permalink
add back compatibility for deprecated metrics 1/n (#5067)
Browse files Browse the repository at this point in the history
* add back compatibility for metrics

* tests

* Add deprecated metric utility functions back to functional (#5062)

* add back *deprecated* metric utility functions to functional

* pep

* pep

* suggestions

* move

Co-authored-by: Jirka Borovec <[email protected]>

* more

* fix

* import

* docs

* tests

* fix

Co-authored-by: Teddy Koker <[email protected]>
  • Loading branch information
Borda and teddykoker authored Dec 11, 2020
1 parent ddc3757 commit 4a3f906
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 12 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549))
- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))
- `WandbLogger` does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))
- Changed `automatic_optimization` to be a model attribute ([#4602](https://github.com/PyTorchLightning/pytorch-lightning/pull/4602))
- Changed `Simple Profiler` report to order by percentage time spent + num calls ([#4880](https://github.com/PyTorchLightning/pytorch-lightning/pull/4880))
- Simplify optimization Logic ([#4984](https://github.com/PyTorchLightning/pytorch-lightning/pull/4984))
Expand All @@ -105,6 +104,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed

- Removed `reorder` parameter of the `auc` metric ([#5004](https://github.com/PyTorchLightning/pytorch-lightning/pull/5004))
- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549))

### Fixed

Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
auc,
auroc,
dice_score,
get_num_classes,
iou,
multiclass_auroc,
precision,
precision_recall,
recall,
stat_scores,
stat_scores_multiple_classes,
iou,
to_categorical,
to_onehot,
)
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix
# TODO: unify metrics between class and functional, add below
Expand Down
217 changes: 208 additions & 9 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,73 @@

import torch

from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc
from pytorch_lightning.metrics.functional.roc import roc as __roc
from pytorch_lightning.metrics.utils import (
to_categorical as __tc,
to_onehot as __to,
get_num_classes as __gnc,
reduce,
class_reduce,
)
from pytorch_lightning.utilities import rank_zero_warn


def to_onehot(
tensor: torch.Tensor,
num_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Converts a dense label tensor to one-hot format
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot`
"""
rank_zero_warn(
"This `to_onehot` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import to_onehot`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __to(tensor, num_classes)


def to_categorical(
tensor: torch.Tensor,
argmax_dim: int = 1
) -> torch.Tensor:
"""
Converts a tensor of probabilities to a dense label tensor
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical`
"""
rank_zero_warn(
"This `to_categorical` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import to_categorical`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __tc(tensor)


def get_num_classes(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
) -> int:
"""
Calculates the number of classes for a given prediction and target tensor.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes`
"""
rank_zero_warn(
"This `get_num_classes` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import get_num_classes`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __gnc(pred,target, num_classes)


def stat_scores(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -333,8 +394,79 @@ def recall(
num_classes=num_classes, class_reduction=class_reduction)[1]


# todo: remove in 1.3
def roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def _roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = _roc(x, y)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label)

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")

fpr = fps / fps[-1]

if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")

tpr = tps / tps[-1]

return fpr, tpr, thresholds


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def __multiclass_roc(
def multiclass_roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
Expand All @@ -343,7 +475,7 @@ def __multiclass_roc(
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
.. warning:: Deprecated
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
Args:
pred: estimated probabilities
Expand All @@ -362,19 +494,24 @@ def __multiclass_roc(
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
>>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
num_classes = get_num_classes(pred, target, num_classes)

class_roc_vals = []
for c in range(num_classes):
pred_c = pred[:, c]

class_roc_vals.append(roc(preds=pred_c, target=target, sample_weights=sample_weight, pos_label=c, num_classes=1))
class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))

return tuple(class_roc_vals)

Expand Down Expand Up @@ -472,7 +609,7 @@ def auroc(

@auc_decorator()
def _auroc(pred, target, sample_weight, pos_label):
return roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label, num_classes=1)
return _roc(pred, target, sample_weight, pos_label)

return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)

Expand Down Expand Up @@ -525,7 +662,7 @@ def multiclass_auroc(

@multiclass_auc_decorator()
def _multiclass_auroc(pred, target, sample_weight, num_classes):
return __multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=num_classes)
return multiclass_roc(pred, target, sample_weight, num_classes)

class_aurocs = _multiclass_auroc(pred=pred, target=target,
sample_weight=sample_weight,
Expand Down Expand Up @@ -672,3 +809,65 @@ def iou(
])

return reduce(scores, reduction=reduction)


# todo: remove in 1.3
def precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Computes precision-recall pairs for different thresholds.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# todo: remove in 1.3
def multiclass_precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
):
"""
Computes precision-recall pairs for different thresholds given a multiclass scores.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
" It will be removed in v1.3.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target, num_classes)
return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes)


# todo: remove in 1.3
def average_precision(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Compute average precision from prediction scores.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.1.0 in favor of"
" `pytorch_lightning.metrics.functional.average_precision import average_precision`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
35 changes: 35 additions & 0 deletions pytorch_lightning/metrics/functional/reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.metrics.utils import reduce as __reduce, class_reduce as __cr
from pytorch_lightning.utilities import rank_zero_warn


def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
rank_zero_warn(
"This `reduce` was deprecated in v1.1.0 in favor of"
" `pytorch_lightning.metrics.utils import reduce`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __reduce(to_reduce=to_reduce, reduction=reduction)


def class_reduce(num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = 'none'):
rank_zero_warn(
"This `class_reduce` was deprecated in v1.1.0 in favor of"
" `pytorch_lightning.metrics.utils import class_reduce`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __cr(num=num, denom=denom, weights=weights, class_reduction=class_reduction)
Loading

0 comments on commit 4a3f906

Please sign in to comment.