Prune metrics base classes 2/n #6530

Merged: 15 commits, Mar 15, 2021
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -65,7 +65,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505))
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
    [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
)


### Removed
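For downstream code, the deprecation amounts to switching the import source. A minimal sketch, using `Accuracy` as a representative metric (an illustrative assumption; this PR itself only touches the base classes and low-level utilities):

```python
# Before: the Lightning-bundled metric class (deprecated, removal planned for v1.5.0)
# from pytorch_lightning.metrics import Accuracy

# After: the same metric from torchmetrics
from torchmetrics import Accuracy

accuracy = Accuracy()  # constructed exactly as before; only the import location changes
```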
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/conv_sequential_example.py
@@ -27,11 +27,11 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchmetrics.functional import accuracy

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo
from pytorch_lightning import Trainer
from torchmetrics.functional import accuracy
from pytorch_lightning.plugins import RPCSequentialPlugin
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE

2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -1,6 +1,6 @@
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import Any, TYPE_CHECKING

import torch

2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/helpers.py
@@ -15,8 +15,8 @@

import numpy as np
import torch
from torchmetrics.utilities.data import select_topk, to_onehot

from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum


100 changes: 24 additions & 76 deletions pytorch_lightning/metrics/compositional.py
@@ -1,14 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Union

import torch
from torchmetrics import Metric
from torchmetrics.metric import CompositionalMetric as __CompositionalMetric

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


class CompositionalMetric(Metric):
"""Composition of two metrics with a specific operator
which will be executed upon metric's compute
class CompositionalMetric(__CompositionalMetric):
r"""
This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`.
.. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0.
"""

def __init__(
@@ -17,76 +33,8 @@ def __init__(
metric_a: Union[Metric, int, float, torch.Tensor],
metric_b: Union[Metric, int, float, torch.Tensor, None],
):
"""
Args:
operator: the operator taking in one (if metric_b is None)
or two arguments. Will be applied to outputs of metric_a.compute()
and (optionally if metric_b is not None) metric_b.compute()
metric_a: first metric whose compute() result is the first argument of operator
metric_b: second metric whose compute() result is the second argument of operator.
For operators taking in only one input, this should be None
"""
super().__init__()

self.op = operator

if isinstance(metric_a, torch.Tensor):
self.register_buffer("metric_a", metric_a)
else:
self.metric_a = metric_a

if isinstance(metric_b, torch.Tensor):
self.register_buffer("metric_b", metric_b)
else:
self.metric_b = metric_b

def _sync_dist(self, dist_sync_fn=None):
# No syncing required here. syncing will be done in metric_a and metric_b
pass

def update(self, *args, **kwargs):
if isinstance(self.metric_a, Metric):
self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

if isinstance(self.metric_b, Metric):
self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

def compute(self):

# also some parsing for kwargs?
if isinstance(self.metric_a, Metric):
val_a = self.metric_a.compute()
else:
val_a = self.metric_a

if isinstance(self.metric_b, Metric):
val_b = self.metric_b.compute()
else:
val_b = self.metric_b

if val_b is None:
return self.op(val_a)

return self.op(val_a, val_b)

def reset(self):
if isinstance(self.metric_a, Metric):
self.metric_a.reset()

if isinstance(self.metric_b, Metric):
self.metric_b.reset()

def persistent(self, mode: bool = False):
if isinstance(self.metric_a, Metric):
self.metric_a.persistent(mode=mode)
if isinstance(self.metric_b, Metric):
self.metric_b.persistent(mode=mode)

def __repr__(self):
repr_str = (
self.__class__.__name__
+ f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
rank_zero_warn(
"This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`."
" It will be removed in v1.5.0", DeprecationWarning
Comment on lines +36 to +38

Member: Should we maybe introduce a temporary decorator/helper function for that? So that we can just forward all init arguments to the base class and have this function raise the warning?

Member (Author): that is a great point! thx :]

Member (Author): well, I'll prepare it in another PR, as we would also need to limit calling all warnings to only once, especially if they are used in functional... (a sketch of such a helper follows this file's diff)
)

return repr_str
super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b)
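A minimal sketch of the helper discussed in the thread above, assuming nothing beyond what the thread describes: a class decorator that forwards all `__init__` arguments to the torchmetrics base class and raises the deprecation warning only once. The name `deprecated_in_favor_of` and the once-only bookkeeping are illustrative assumptions, not the helper that eventually landed.

```python
from functools import wraps

from pytorch_lightning.utilities import rank_zero_warn


def deprecated_in_favor_of(base_cls, remove_in="v1.5.0"):
    """Class decorator (sketch): forward ``__init__`` to ``base_cls`` and warn once."""

    def decorator(cls):
        warned = {"done": False}  # emit the deprecation warning only on the first instantiation

        @wraps(base_cls.__init__)
        def __init__(self, *args, **kwargs):
            if not warned["done"]:
                rank_zero_warn(
                    f"`{cls.__name__}` was deprecated since v1.3.0 in favor of"
                    f" `{base_cls.__module__}.{base_cls.__name__}`."
                    f" It will be removed in {remove_in}.", DeprecationWarning
                )
                warned["done"] = True
            base_cls.__init__(self, *args, **kwargs)  # forward every argument unchanged

        cls.__init__ = __init__
        return cls

    return decorator
```

With such a helper, the wrapper above would reduce to applying the decorator to a bare subclass, e.g. `@deprecated_in_favor_of(__CompositionalMetric)` on `class CompositionalMetric(__CompositionalMetric): ...`, and the explicit `__init__` override could be dropped.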
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/auc.py
@@ -14,8 +14,7 @@
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _stable_1d_sort
from torchmetrics.utilities.data import _stable_1d_sort


def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/classification.py
@@ -15,11 +15,12 @@
from typing import Callable, Optional, Sequence, Tuple

import torch
from torchmetrics.utilities import class_reduce, reduce
from torchmetrics.utilities.data import get_num_classes, to_categorical

from pytorch_lightning.metrics.functional.auc import auc as __auc
from pytorch_lightning.metrics.functional.auroc import auroc as __auroc
from pytorch_lightning.metrics.functional.iou import iou as __iou
from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical
from pytorch_lightning.utilities import rank_zero_warn


pytorch_lightning/metrics/functional/explained_variance.py
@@ -14,8 +14,7 @@
from typing import Sequence, Tuple, Union

import torch

from pytorch_lightning.metrics.utils import _check_same_shape
from torchmetrics.utilities.checks import _check_same_shape


def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/f_beta.py
@@ -14,8 +14,8 @@
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
from torchmetrics.utilities import class_reduce
from torchmetrics.utilities.checks import _input_format_classification_one_hot


def _fbeta_update(
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/iou.py
@@ -14,9 +14,10 @@
from typing import Optional

import torch
from torchmetrics.utilities import reduce
from torchmetrics.utilities.data import get_num_classes

from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
from pytorch_lightning.metrics.utils import get_num_classes, reduce


def _iou_from_confmat(
pytorch_lightning/metrics/functional/mean_absolute_error.py
@@ -14,8 +14,7 @@
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _check_same_shape
from torchmetrics.utilities.checks import _check_same_shape


def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
pytorch_lightning/metrics/functional/mean_relative_error.py
@@ -14,8 +14,7 @@
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _check_same_shape
from torchmetrics.utilities.checks import _check_same_shape


def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
pytorch_lightning/metrics/functional/mean_squared_error.py
@@ -14,8 +14,7 @@
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _check_same_shape
from torchmetrics.utilities.checks import _check_same_shape


def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
pytorch_lightning/metrics/functional/mean_squared_log_error.py
@@ -14,8 +14,7 @@
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _check_same_shape
from torchmetrics.utilities.checks import _check_same_shape


def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]:
8 changes: 4 additions & 4 deletions pytorch_lightning/metrics/functional/psnr.py
@@ -14,9 +14,9 @@
from typing import Optional, Tuple, Union

import torch
from torchmetrics.utilities import reduce

from pytorch_lightning import utilities
from pytorch_lightning.metrics import utils
from pytorch_lightning.utilities import rank_zero_warn


def _psnr_compute(
Expand All @@ -28,7 +28,7 @@ def _psnr_compute(
) -> torch.Tensor:
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
return utils.reduce(psnr, reduction=reduction)
return reduce(psnr, reduction=reduction)


def _psnr_update(preds: torch.Tensor,
@@ -97,7 +97,7 @@ def psnr(

"""
if dim is None and reduction != 'elementwise_mean':
utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')

if data_range is None:
if dim is not None:
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/r2score.py
@@ -14,8 +14,8 @@
from typing import Tuple

import torch
from torchmetrics.utilities.checks import _check_same_shape

from pytorch_lightning.metrics.utils import _check_same_shape
from pytorch_lightning.utilities import rank_zero_warn


4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/ssim.py
@@ -15,8 +15,8 @@

import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.utils import _check_same_shape, reduce
from torchmetrics.utilities import reduce
from torchmetrics.utilities.checks import _check_same_shape


def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device):
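The functional-metric changes above all follow a single pattern: low-level helpers previously imported from `pytorch_lightning.metrics.utils` now come from the corresponding `torchmetrics.utilities` submodules. Summarized as a sketch for downstream code, using only the locations that appear in the diffs above:

```python
# Before (being pruned from pytorch_lightning.metrics.utils):
# from pytorch_lightning.metrics.utils import (
#     class_reduce, get_num_classes, reduce, select_topk,
#     to_categorical, to_onehot, _check_same_shape, _stable_1d_sort,
# )

# After: the torchmetrics locations used throughout this PR
from torchmetrics.utilities import class_reduce, reduce
from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification_one_hot
from torchmetrics.utilities.data import (
    _stable_1d_sort, get_num_classes, select_topk, to_categorical, to_onehot,
)
```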