From 944c79607ea74b6fbd04624caf735fc68d939a4f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 20:28:18 +0100 Subject: [PATCH] Prune metrics base classes 2/n (#6530) * base class * extensions * chlog * _stable_1d_sort * _check_same_shape * _input_format_classification_one_hot * utils * to_onehot * select_topk * to_categorical * get_num_classes * reduce * class_reduce * tests (cherry picked from commit 6453091b8ab3713e2d58bad7acc9a4345dc5d10b) --- .../basic_examples/conv_sequential_example.py | 4 +- pytorch_lightning/accelerators/gpu.py | 2 +- pytorch_lightning/core/step_result.py | 2 +- pytorch_lightning/metrics/compositional.py | 100 +--- .../metrics/functional/classification.py | 21 + pytorch_lightning/metrics/functional/psnr.py | 8 +- pytorch_lightning/metrics/metric.py | 464 +----------------- pytorch_lightning/metrics/utils.py | 3 + pytorch_lightning/trainer/callback_hook.py | 2 +- .../trainer/connectors/callback_connector.py | 7 +- .../logger_connector/metrics_holder.py | 3 +- requirements.txt | 1 + tests/accelerators/test_dp.py | 2 +- tests/deprecated_api/test_remove_1-5.py | 2 +- tests/metrics/classification/test_inputs.py | 2 +- .../metrics/functional/test_classification.py | 2 +- tests/metrics/functional/test_reduction.py | 3 +- tests/metrics/test_metric_lightning.py | 8 +- 18 files changed, 98 insertions(+), 538 deletions(-) diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index b558020838cdb..6cfb6109f04fc 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -189,6 +189,7 @@ def instantiate_datamodule(args): ]) cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule( + data_dir=args.data_dir, batch_size=args.batch_size, train_transforms=train_transforms, test_transforms=test_transforms, @@ -206,6 +207,7 @@ def instantiate_datamodule(args): parser = ArgumentParser(description="Pipe Example") parser.add_argument("--use_rpc_sequential", action="store_true") + parser.add_argument("--manual_optimization", action="store_true") parser = Trainer.add_argparse_args(parser) parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser) args = parser.parse_args() @@ -216,7 +218,7 @@ def instantiate_datamodule(args): if args.use_rpc_sequential: plugins = RPCSequentialPlugin() - model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization) + model = LitResnet(batch_size=args.batch_size, manual_optimization=args.manual_optimization) trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None) trainer.fit(model, cifar10_dm) diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index af9ce25f902b3..5c5dc5cc6f531 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,6 +1,6 @@ import logging import os -from typing import TYPE_CHECKING, Any +from typing import Any, TYPE_CHECKING import torch diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f8d7a2ffe3a23..3961586f4946a 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/metrics/compositional.py 
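The conv_sequential_example.py hunk above replaces the inverted --automatic_optimization lookup with an explicit --manual_optimization flag. A minimal standalone sketch (plain argparse, nothing Lightning-specific assumed) of how such a store_true flag behaves:

    from argparse import ArgumentParser

    parser = ArgumentParser(description="Pipe Example")
    parser.add_argument("--use_rpc_sequential", action="store_true")
    parser.add_argument("--manual_optimization", action="store_true")

    # store_true flags default to False and flip to True only when passed explicitly
    args = parser.parse_args([])
    assert args.manual_optimization is False
    args = parser.parse_args(["--manual_optimization"])
    assert args.manual_optimization is True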
b/pytorch_lightning/metrics/compositional.py index df98d16a3ef7e..d51332c43b6b4 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -1,14 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Union import torch +from torchmetrics.metric import CompositionalMetric as _CompositionalMetric -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics import Metric +from pytorch_lightning.utilities import rank_zero_warn -class CompositionalMetric(Metric): - """Composition of two metrics with a specific operator - which will be executed upon metric's compute +class CompositionalMetric(_CompositionalMetric): + r""" + This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`. + .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0. """ def __init__( @@ -17,76 +33,8 @@ def __init__( metric_a: Union[Metric, int, float, torch.Tensor], metric_b: Union[Metric, int, float, torch.Tensor, None], ): - """ - - Args: - operator: the operator taking in one (if metric_b is None) - or two arguments. Will be applied to outputs of metric_a.compute() - and (optionally if metric_b is not None) metric_b.compute() - metric_a: first metric whose compute() result is the first argument of operator - metric_b: second metric whose compute() result is the second argument of operator. - For operators taking in only one input, this should be None - """ - super().__init__() - - self.op = operator - - if isinstance(metric_a, torch.Tensor): - self.register_buffer("metric_a", metric_a) - else: - self.metric_a = metric_a - - if isinstance(metric_b, torch.Tensor): - self.register_buffer("metric_b", metric_b) - else: - self.metric_b = metric_b - - def _sync_dist(self, dist_sync_fn=None): - # No syncing required here. syncing will be done in metric_a and metric_b - pass - - def update(self, *args, **kwargs): - if isinstance(self.metric_a, Metric): - self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) - - if isinstance(self.metric_b, Metric): - self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) - - def compute(self): - - # also some parsing for kwargs? 
- if isinstance(self.metric_a, Metric): - val_a = self.metric_a.compute() - else: - val_a = self.metric_a - - if isinstance(self.metric_b, Metric): - val_b = self.metric_b.compute() - else: - val_b = self.metric_b - - if val_b is None: - return self.op(val_a) - - return self.op(val_a, val_b) - - def reset(self): - if isinstance(self.metric_a, Metric): - self.metric_a.reset() - - if isinstance(self.metric_b, Metric): - self.metric_b.reset() - - def persistent(self, mode: bool = False): - if isinstance(self.metric_a, Metric): - self.metric_a.persistent(mode=mode) - if isinstance(self.metric_b, Metric): - self.metric_b.persistent(mode=mode) - - def __repr__(self): - repr_str = ( - self.__class__.__name__ - + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" + rank_zero_warn( + "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." + " It will be removed in v1.5.0", DeprecationWarning ) - - return repr_str + super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index fae9e0770f88d..7281ca3f83717 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -123,6 +123,7 @@ def stat_scores( return tp, fp, tn, fn, sup +# todo: remove in 1.4 def stat_scores_multiple_classes( pred: torch.Tensor, target: torch.Tensor, @@ -136,6 +137,9 @@ def stat_scores_multiple_classes( .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores` + Raises: + ValueError: + If ``reduction`` is not one of ``"none"``, ``"sum"`` or ``"elementwise_mean"``. """ rank_zero_warn( @@ -211,6 +215,7 @@ def _confmat_normalize(cm): return cm +# todo: remove in 1.4 def precision_recall( pred: torch.Tensor, target: torch.Tensor, @@ -269,6 +274,7 @@ def precision_recall( return precision, recall +# todo: remove in 1.4 def precision( pred: torch.Tensor, target: torch.Tensor, @@ -312,6 +318,7 @@ def precision( return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0] +# todo: remove in 1.4 def recall( pred: torch.Tensor, target: torch.Tensor, @@ -509,6 +516,7 @@ def auc( return __auc(x, y) +# todo: remove in 1.4 def auc_decorator() -> Callable: rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning) @@ -525,6 +533,7 @@ def new_func(*args, **kwargs) -> torch.Tensor: return wrapper +# todo: remove in 1.4 def multiclass_auc_decorator() -> Callable: rank_zero_warn( "This `multiclass_auc_decorator` was deprecated in v1.2.0." @@ -547,6 +556,7 @@ def new_func(*args, **kwargs) -> torch.Tensor: return wrapper +# todo: remove in 1.4 def auroc( pred: torch.Tensor, target: torch.Tensor, @@ -589,6 +599,7 @@ def auroc( ) +# todo: remove in 1.4 def multiclass_auroc( pred: torch.Tensor, target: torch.Tensor, @@ -612,6 +623,16 @@ def multiclass_auroc( Return: Tensor containing ROCAUC score + Raises: + ValueError: + If ``pred`` don't sum up to ``1`` over classes for ``Multiclass AUROC``. + ValueError: + If number of classes found in ``target`` does not equal the number of + columns in ``pred``. + ValueError: + If number of classes deduced from ``pred`` does not equal the number of + classes passed in ``num_classes``. 
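For context on the compositional.py hunk: the removed code shows that a CompositionalMetric applies its operator to the compute() results of its operands, and the torchmetrics class it now delegates to keeps that contract. A minimal sketch under that assumption; SumMetric is a hypothetical example metric, not part of the patch:

    import torch
    from torchmetrics import Metric


    class SumMetric(Metric):
        # hypothetical example metric used only to demonstrate composition
        def __init__(self):
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x: torch.Tensor):
            self.total += x.sum()

        def compute(self):
            return self.total


    a, b = SumMetric(), SumMetric()
    composed = a + 2 * b      # CompositionalMetric(torch.add, a, CompositionalMetric(torch.mul, 2, b))
    a.update(torch.tensor([1.0, 2.0]))
    b.update(torch.tensor([3.0]))
    print(composed.compute())  # tensor(9.) == a.compute() + 2 * b.compute()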
+ Example: >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index 434b2ae60218d..87ee1a93c9d19 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -15,8 +15,8 @@ import torch -from pytorch_lightning import utilities -from pytorch_lightning.metrics import utils +from pytorch_lightning.metrics.utils import reduce +from pytorch_lightning.utilities import rank_zero_warn def _psnr_compute( @@ -28,7 +28,7 @@ def _psnr_compute( ) -> torch.Tensor: psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) - return utils.reduce(psnr, reduction=reduction) + return reduce(psnr, reduction=reduction) def _psnr_update(preds: torch.Tensor, @@ -93,7 +93,7 @@ def psnr( """ if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') + rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') if data_range is None: if dim is not None: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index ab198356f7279..730011d998f10 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -11,52 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import inspect -from abc import ABC, abstractmethod -from collections.abc import Sequence -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from torch import nn +from torchmetrics import Metric as _Metric +from torchmetrics import MetricCollection as _MetricCollection -from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import gather_all_tensors +from pytorch_lightning.utilities.distributed import rank_zero_warn -class Metric(nn.Module, ABC): - """ - Base class for all metrics present in the Metrics API. - - Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to - handle distributed synchronization and per-step metric computation. - - Override ``update()`` and ``compute()`` functions to implement your own metric. Use - ``add_state()`` to register metric state variables which keep track of state on each - call of ``update()`` and are synchronized across processes when ``compute()`` is called. - - Note: - Metric state variables can either be ``torch.Tensors`` or an empty list which can we used - to store `torch.Tensors``. - - Note: - Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` - is valid, but it won't return the metric value at the current step. A call to ``forward()`` - automatically calls ``update()`` and also returns the metric value at the current step. - - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. 
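The psnr.py hunk keeps the same _psnr_compute formula and only changes where reduce and rank_zero_warn are imported from. A small worked check, with base set to 10, showing the expression is the usual 10 * log10(data_range**2 / mse); the tensors are made-up example values:

    import torch

    preds = torch.tensor([0.0, 1.0, 2.0, 3.0])
    target = torch.tensor([0.0, 1.0, 2.0, 2.0])

    sum_squared_error = torch.sum((preds - target) ** 2)    # 1.0
    n_obs = torch.tensor(preds.numel(), dtype=torch.float)  # 4.0
    data_range = target.max() - target.min()                # 2.0
    base = 10

    # formula as written in _psnr_compute
    psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
    psnr = psnr_base_e * (10 / torch.log(torch.tensor(float(base))))

    # equivalent closed form
    expected = 10 * torch.log10(data_range ** 2 / (sum_squared_error / n_obs))
    assert torch.isclose(psnr, expected)
    print(psnr)  # tensor(12.0412)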
default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None +class Metric(_Metric): + r""" + This implementation refers to :class:`~torchmetrics.Metric`. + + .. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0. """ def __init__( @@ -66,356 +34,78 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__() - - self.dist_sync_on_step = dist_sync_on_step - self.compute_on_step = compute_on_step - self.process_group = process_group - self.dist_sync_fn = dist_sync_fn - self._to_sync = True - - self._update_signature = inspect.signature(self.update) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - self._computed = None - self._forward_cache = None - - # initialize state - self._defaults = {} - self._persistent = {} - self._reductions = {} - - def add_state( - self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False - ): - """ - Adds metric state variable. Only used by subclasses. - - Args: - name: The name of the state variable. The variable will then be accessible at ``self.name``. - default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be - reset to this value when ``self.reset()`` is called. - dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction - only makes sense if the state is a list, and not a tensor. The user can also pass a custom - function in this parameter. - persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. - Default is ``False``. - - Note: - Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. - However, there won't be any reduction function applied to the synchronized metric state. - - The metric states would be synced as follows - - - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across - the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric - state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``. - - - If the metric state is a ``list``, the synced value will be a ``list`` containing the - combined elements from all processes. - - Note: - When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow - the format discussed in the above note. 
- - """ - if ( - not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 - or (isinstance(default, list) and len(default) != 0) # noqa: W503 - ): - raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") - - if dist_reduce_fx == "sum": - dist_reduce_fx = dim_zero_sum - elif dist_reduce_fx == "mean": - dist_reduce_fx = dim_zero_mean - elif dist_reduce_fx == "cat": - dist_reduce_fx = dim_zero_cat - elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): - raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]") - - setattr(self, name, default) - - self._defaults[name] = deepcopy(default) - self._persistent[name] = persistent - self._reductions[name] = dist_reduce_fx - - @torch.jit.unused - def forward(self, *args, **kwargs): - """ - Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. - """ - # add current step - with torch.no_grad(): - self.update(*args, **kwargs) - self._forward_cache = None - - if self.compute_on_step: - self._to_sync = self.dist_sync_on_step - - # save context before switch - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # call reset, update, compute, on single batch - self.reset() - self.update(*args, **kwargs) - self._forward_cache = self.compute() - - # restore context - for attr, val in cache.items(): - setattr(self, attr, val) - self._to_sync = True - self._computed = None - - return self._forward_cache - - def _sync_dist(self, dist_sync_fn=gather_all_tensors): - input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} - output_dict = apply_to_collection( - input_dict, - torch.Tensor, - dist_sync_fn, - group=self.process_group, + rank_zero_warn( + "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." 
+ " It will be removed in v1.5.0", DeprecationWarning + ) + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - - for attr, reduction_fn in self._reductions.items(): - # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], torch.Tensor): - output_dict[attr] = torch.stack(output_dict[attr]) - elif isinstance(output_dict[attr][0], list): - output_dict[attr] = _flatten(output_dict[attr]) - - assert isinstance(reduction_fn, (Callable)) or reduction_fn is None - reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] - setattr(self, attr, reduced) - - def _wrap_update(self, update): - - @functools.wraps(update) - def wrapped_func(*args, **kwargs): - self._computed = None - return update(*args, **kwargs) - - return wrapped_func - - def _wrap_compute(self, compute): - - @functools.wraps(compute) - def wrapped_func(*args, **kwargs): - # return cached value - if self._computed is not None: - return self._computed - - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - synced = False - if self._to_sync and dist_sync_fn is not None: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # sync - self._sync_dist(dist_sync_fn) - synced = True - - self._computed = compute(*args, **kwargs) - if synced: - # if we synced, restore to cache so that we can continue to accumulate un-synced state - for attr, val in cache.items(): - setattr(self, attr, val) - - return self._computed - - return wrapped_func - - @abstractmethod - def update(self) -> None: # pylint: disable=E0202 - """ - Override this method to update the state variables of your metric class. - """ - pass - - @abstractmethod - def compute(self): # pylint: disable=E0202 - """ - Override this method to compute the final metric value from state variables - synchronized across the distributed backend. - """ - pass - - def reset(self): - """ - This method automatically resets the metric state variables to their default value. 
- """ - for attr, default in self._defaults.items(): - current_val = getattr(self, attr) - if isinstance(default, torch.Tensor): - setattr(self, attr, deepcopy(default).to(current_val.device)) - else: - setattr(self, attr, deepcopy(default)) - - def clone(self): - """ Make a copy of the metric """ - return deepcopy(self) - - def __getstate__(self): - # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} - - def __setstate__(self, state): - # manually restore update and compute functions for pickling - self.__dict__.update(state) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - - def _apply(self, fn): - """Overwrite _apply function such that we can also move metric states - to the correct device when `.to`, `.cuda`, etc methods are called - """ - self = super()._apply(fn) - # Also apply fn to metric states - for key in self._defaults.keys(): - current_val = getattr(self, key) - if isinstance(current_val, torch.Tensor): - setattr(self, key, fn(current_val)) - elif isinstance(current_val, Sequence): - setattr(self, key, [fn(cur_v) for cur_v in current_val]) - else: - raise TypeError( - "Expected metric state to be either a torch.Tensor" - f"or a list of torch.Tensor, but encountered {current_val}" - ) - return self - - def persistent(self, mode: bool = False): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for key in self._persistent.keys(): - self._persistent[key] = mode - - def state_dict(self, destination=None, prefix='', keep_vars=False): - destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - # Register metric states to be part of the state_dict - for key in self._defaults.keys(): - if self._persistent[key]: - current_val = getattr(self, key) - if not keep_vars: - if torch.is_tensor(current_val): - current_val = current_val.detach() - elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] - destination[prefix + key] = current_val - return destination - - def _filter_kwargs(self, **kwargs): - """ filter kwargs such that they match the update signature of the metric """ - - # filter all parameters based on update signature except those of - # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) - _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - filtered_kwargs = { - k: v - for k, v in kwargs.items() if k in self._update_signature.parameters.keys() - and self._update_signature.parameters[k].kind not in _params - } - - # if no kwargs filtered, return al kwargs as default - if not filtered_kwargs: - filtered_kwargs = kwargs - return filtered_kwargs def __hash__(self): - hash_vals = [self.__class__.__name__] - - for key in self._defaults.keys(): - val = getattr(self, key) - # Special case: allow list values, so long - # as their elements are hashable - if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): - hash_vals.extend(val) - else: - hash_vals.append(val) - - return hash(tuple(hash_vals)) + return super().__hash__() def __add__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.add, self, other) def __and__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_and, self, other) def __eq__(self, 
other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.eq, self, other) def __floordiv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.floor_divide, self, other) def __ge__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.ge, self, other) def __gt__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.gt, self, other) def __le__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.le, self, other) def __lt__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.lt, self, other) def __matmul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.matmul, self, other) def __mod__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.fmod, self, other) def __mul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.mul, self, other) def __ne__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.ne, self, other) def __or__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_or, self, other) def __pow__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.pow, self, other) def __radd__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.add, other, self) def __rand__(self, other: Any): @@ -426,72 +116,58 @@ def __rand__(self, other: Any): def __rfloordiv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.floor_divide, other, self) def __rmatmul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.matmul, other, self) def __rmod__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.fmod, other, self) def __rmul__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.mul, other, self) def __ror__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_or, other, self) def __rpow__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.pow, other, self) def __rsub__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.sub, other, self) def __rtruediv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.true_divide, other, self) def __rxor__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return 
CompositionalMetric(torch.bitwise_xor, other, self) def __sub__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.sub, self, other) def __truediv__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.true_divide, self, other) def __xor__(self, other: Any): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_xor, self, other) def __abs__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.abs, self, None) def __inv__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.bitwise_not, self, None) def __invert__(self): @@ -499,12 +175,10 @@ def __invert__(self): def __neg__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(_neg, self, None) def __pos__(self): from pytorch_lightning.metrics.compositional import CompositionalMetric - return CompositionalMetric(torch.abs, self, None) @@ -512,100 +186,16 @@ def _neg(tensor: torch.Tensor): return -torch.abs(tensor) -class MetricCollection(nn.ModuleDict): - """ - MetricCollection class can be used to chain metrics that have the same - call pattern into one single class. - - Args: - metrics: One of the following - - * list or tuple: if metrics are passed in as a list, will use the - metrics class name as key for output dict. Therefore, two metrics - of the same class cannot be chained this way. - - * dict: if metrics are passed in as a dict, will use each key in the - dict as key for output dict. Use this format if you want to chain - together multiple of the same metric with different parameters. - - Example (input as list): - - >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall - >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - >>> metrics = MetricCollection([Accuracy(), - ... Precision(num_classes=3, average='macro'), - ... Recall(num_classes=3, average='macro')]) - >>> metrics(preds, target) - {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} - - Example (input as dict): - - >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), - ... 'macro_recall': Recall(num_classes=3, average='macro')}) - >>> metrics(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} +class MetricCollection(_MetricCollection): + r""" + This implementation refers to :class:`~torchmetrics.MetricCollection`. + .. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0. 
""" def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - super().__init__() - if isinstance(metrics, dict): - # Check all values are metrics - for name, metric in metrics.items(): - if not isinstance(metric, Metric): - raise ValueError( - f"Value {metric} belonging to key {name}" - " is not an instance of `pl.metrics.Metric`" - ) - self[name] = metric - elif isinstance(metrics, (tuple, list)): - for metric in metrics: - if not isinstance(metric, Metric): - raise ValueError( - f"Input {metric} to `MetricCollection` is not a instance" - " of `pl.metrics.Metric`" - ) - name = metric.__class__.__name__ - if name in self: - raise ValueError(f"Encountered two metrics both named {name}") - self[name] = metric - else: - raise ValueError("Unknown input to MetricCollection.") - - def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 - """ - Iteratively call forward for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} - - def update(self, *args, **kwargs): # pylint: disable=E0202 - """ - Iteratively call update for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - for _, m in self.items(): - m_kwargs = m._filter_kwargs(**kwargs) - m.update(*args, **m_kwargs) - - def compute(self) -> Dict[str, Any]: - return {k: m.compute() for k, m in self.items()} - - def reset(self): - """ Iteratively call reset for each metric """ - for _, m in self.items(): - m.reset() - - def clone(self): - """ Make a copy of the metric collection """ - return deepcopy(self) - - def persistent(self, mode: bool = True): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for _, m in self.items(): - m.persistent(mode) + rank_zero_warn( + "This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." + " It will be removed in v1.5.0", DeprecationWarning + ) + super().__init__(metrics=metrics) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 7abf260a822ef..5084294bfbf98 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -245,6 +245,9 @@ def class_reduce( - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - ``'none'`` or ``None``: returns calculated metric per class + Raises: + ValueError: + If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``. 
""" valid_reduction = ("micro", "macro", "weighted", "none", None) if class_reduction == "micro": diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 60e9183ac42f7..d33338055a5b1 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import List, Dict, Any, Type, Callable +from typing import Any, Callable, Dict, List, Type from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 14694c8f77811..ee768c05cc8a2 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -14,12 +14,7 @@ import os from typing import List, Union -from pytorch_lightning.callbacks import ( - Callback, - ModelCheckpoint, - ProgressBar, - ProgressBarBase, -) +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 82f328a927485..554f1d3faf9ed 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -15,8 +15,7 @@ from typing import Any import torch - -from pytorch_lightning.metrics.metric import Metric +from torchmetrics import Metric class MetricsHolder: diff --git a/requirements.txt b/requirements.txt index bdfd6601ba4c2..f196b5e639bf5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ PyYAML>=5.1, !=5.4.* # OmegaConf requirement >=5.1 tqdm>=4.41.0 fsspec[http]>=0.8.1 tensorboard>=2.2.0 +torchmetrics>=0.2.0 diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 15faf98d94d57..8252aac9e9092 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -23,10 +23,10 @@ from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core import memory from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate from tests.helpers import BoringModel, RandomDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.simple_models import ClassificationModel -from tests.base import EvalModelTemplate class CustomClassificationModelDP(ClassificationModel): diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 415f1d040ba70..c6b2dc24b35ff 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -17,7 +17,7 @@ import pytest -from pytorch_lightning import Trainer, Callback +from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import WandbLogger from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index a78d799b1a07d..2b7be8caa7a0d 
100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -1,9 +1,9 @@ import pytest import torch from torch import rand, randint +from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from pytorch_lightning.metrics.utils import select_topk, to_onehot from tests.metrics.classification.inputs import _input_binary as _bin from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob from tests.metrics.classification.inputs import _input_multiclass as _mc diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 39622c4cd3550..bca50867dcb44 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,10 +1,10 @@ import pytest import torch +from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import dice_score from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve -from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot def test_onehot(): diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py index 03a34f6c5a25b..9949c8086a44a 100644 --- a/tests/metrics/functional/test_reduction.py +++ b/tests/metrics/functional/test_reduction.py @@ -1,7 +1,6 @@ import pytest import torch - -from pytorch_lightning.metrics.utils import class_reduce, reduce +from torchmetrics.utilities import class_reduce, reduce def test_reduce(): diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e..e52e39cb16488 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,11 +1,13 @@ import torch +from torchmetrics import Metric as TMetric from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric, MetricCollection +from pytorch_lightning.metrics import Metric as PLMetric +from pytorch_lightning.metrics import MetricCollection from tests.helpers.boring_model import BoringModel -class SumMetric(Metric): +class SumMetric(TMetric): def __init__(self): super().__init__() @@ -18,7 +20,7 @@ def compute(self): return self.x -class DiffMetric(Metric): +class DiffMetric(PLMetric): def __init__(self): super().__init__()
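test_reduction.py likewise switches reduce and class_reduce to torchmetrics.utilities. A short sketch of both helpers with made-up example numbers, assuming the signatures match the pl.metrics.utils versions whose Raises: docs are extended in the utils.py hunk above:

    import torch
    from torchmetrics.utilities import class_reduce, reduce

    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    print(reduce(x, reduction="elementwise_mean"))  # tensor(2.5000)
    print(reduce(x, reduction="sum"))               # tensor(10.)

    num = torch.tensor([1.0, 3.0])      # e.g. true positives per class
    denom = torch.tensor([2.0, 3.0])    # e.g. predicted positives per class
    weights = torch.tensor([2.0, 3.0])  # e.g. support per class, used for 'weighted'
    print(class_reduce(num, denom, weights, class_reduction="micro"))  # tensor(0.8000)
    print(class_reduce(num, denom, weights, class_reduction="macro"))  # tensor(0.7500)
    # anything outside {"micro", "macro", "weighted", "none", None} raises ValueError,
    # as documented in the Raises: section added to class_reduce above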