-
Notifications
You must be signed in to change notification settings - Fork 423
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into bugfix/no_iter_allowed
- Loading branch information
Showing
14 changed files
with
304 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
.. customcarditem:: | ||
:header: Log Cosh Error | ||
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg | ||
:tags: Regression | ||
|
||
.. include:: ../links.rst | ||
|
||
############## | ||
Log Cosh Error | ||
############## | ||
|
||
Module Interface | ||
________________ | ||
|
||
.. autoclass:: torchmetrics.LogCoshError | ||
:noindex: | ||
|
||
Functional Interface | ||
____________________ | ||
|
||
.. autofunction:: torchmetrics.functional.log_cosh_error | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs | ||
from torchmetrics.utilities.checks import _check_same_shape | ||
|
||
|
||
def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: | ||
if preds.ndim == 2: | ||
return preds, target | ||
return preds.unsqueeze(1), target.unsqueeze(1) | ||
|
||
|
||
def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]: | ||
"""Updates and returns variables required to compute LogCosh error. | ||
Checks for same shape of input tensors. | ||
Args: | ||
preds: Predicted tensor | ||
target: Ground truth tensor | ||
Return: | ||
Sum of LogCosh error over examples, and total number of examples | ||
""" | ||
_check_same_shape(preds, target) | ||
_check_data_shape_to_num_outputs(preds, target, num_outputs) | ||
|
||
preds, target = _unsqueeze_tensors(preds, target) | ||
diff = preds - target | ||
sum_log_cosh_error = torch.log((torch.exp(diff) + torch.exp(-diff)) / 2).sum(0).squeeze() | ||
n_obs = torch.tensor(target.shape[0], device=preds.device) | ||
return sum_log_cosh_error, n_obs | ||
|
||
|
||
def _log_cosh_error_compute(sum_log_cosh_error: Tensor, n_obs: Tensor) -> Tensor: | ||
"""Computes Mean Squared Error. | ||
Args: | ||
sum_squared_error: Sum of LogCosh errors over all observations | ||
n_obs: Number of predictions or observations | ||
""" | ||
return (sum_log_cosh_error / n_obs).squeeze() | ||
|
||
|
||
def log_cosh_error(preds: Tensor, target: Tensor) -> Tensor: | ||
r"""Compute the `LogCosh Error`_. | ||
.. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(\hat{y - y})}{2}\right) | ||
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. | ||
Args: | ||
preds: estimated labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)`` | ||
target: ground truth labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)`` | ||
Return: | ||
Tensor with LogCosh error | ||
Example (single output regression):: | ||
>>> from torchmetrics.functional import log_cosh_error | ||
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) | ||
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) | ||
>>> log_cosh_error(preds, target) | ||
tensor(0.3523) | ||
Example (multi output regression):: | ||
>>> from torchmetrics.functional import log_cosh_error | ||
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]]) | ||
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]]) | ||
>>> log_cosh_error(preds, target) | ||
tensor([0.9176, 0.4277, 0.2194]) | ||
""" | ||
sum_log_cosh_error, n_obs = _log_cosh_error_update( | ||
preds, target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1] | ||
) | ||
return _log_cosh_error_compute(sum_log_cosh_error, n_obs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torchmetrics.functional.regression.log_cosh import _log_cosh_error_compute, _log_cosh_error_update | ||
from torchmetrics.metric import Metric | ||
|
||
|
||
class LogCoshError(Metric): | ||
r"""Compute the `LogCosh Error`_. | ||
.. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(\hat{y - y})}{2}\right) | ||
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. | ||
Args: | ||
num_outputs: Number of outputs in multioutput setting | ||
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. | ||
Example (single output regression):: | ||
>>> from torchmetrics import LogCoshError | ||
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) | ||
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) | ||
>>> log_cosh_error = LogCoshError() | ||
>>> log_cosh_error(preds, target) | ||
tensor(0.3523) | ||
Example (multi output regression):: | ||
>>> from torchmetrics import LogCoshError | ||
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]]) | ||
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]]) | ||
>>> log_cosh_error = LogCoshError(num_outputs=3) | ||
>>> log_cosh_error(preds, target) | ||
tensor([0.9176, 0.4277, 0.2194]) | ||
""" | ||
|
||
is_differentiable = True | ||
higher_is_better = False | ||
full_state_update = False | ||
sum_log_cosh_error: Tensor | ||
total: Tensor | ||
|
||
def __init__(self, num_outputs: int = 1, **kwargs: Any) -> None: | ||
super().__init__(**kwargs) | ||
|
||
if not isinstance(num_outputs, int) and num_outputs < 1: | ||
raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}") | ||
self.num_outputs = num_outputs | ||
self.add_state("sum_log_cosh_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") | ||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") | ||
|
||
def update(self, preds: Tensor, target: Tensor) -> None: | ||
"""Update state with predictions and targets. | ||
Args: | ||
preds: estimated labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)`` | ||
target: ground truth labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)`` | ||
Raises: | ||
ValueError: | ||
If ``preds`` or ``target`` has multiple outputs when ``num_outputs=1`` | ||
""" | ||
sum_log_cosh_error, n_obs = _log_cosh_error_update(preds, target, self.num_outputs) | ||
self.sum_log_cosh_error += sum_log_cosh_error | ||
self.total += n_obs | ||
|
||
def compute(self) -> Tensor: | ||
"""Compute LogCosh error over state.""" | ||
return _log_cosh_error_compute(self.sum_log_cosh_error, self.total) |
Oops, something went wrong.