Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Explained Variance Metric + metric fix #4013

Merged
merged 5 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ MeanSquaredLogError
:noindex:


ExplainedVariance
^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
:noindex:


Functional Metrics
==================

Expand Down
7 changes: 6 additions & 1 deletion pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from pytorch_lightning.metrics.metric import Metric

from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
from pytorch_lightning.metrics.regression import (
MeanSquaredError,
MeanAbsoluteError,
MeanSquaredLogError,
ExplainedVariance,
)
8 changes: 6 additions & 2 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
the format discussed in the above note.

"""
if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0):
if (
not isinstance(default, torch.Tensor)
and not isinstance(default, list) # noqa: W503
or (isinstance(default, list) and len(default) != 0) # noqa: W503
):
raise ValueError(
"state variable must be a tensor or any empty list (where you can append tensors)"
)
Expand Down Expand Up @@ -163,7 +167,7 @@ def _sync_dist(self):
elif isinstance(output_dict[attr][0], list):
output_dict[attr] = _flatten(output_dict[attr])

assert isinstance(reduction_fn, (Callable, None))
assert isinstance(reduction_fn, (Callable)) or reduction_fn is None
reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
setattr(self, attr, reduced)

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance
63 changes: 63 additions & 0 deletions pytorch_lightning/metrics/regression/explained_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
from typing import Any, Callable, Optional, Union

from pytorch_lightning.metrics.metric import Metric


class ExplainedVariance(Metric):
"""
Computes explained variance.

Example:

>>> from pytorch_lightning.metrics import ExplainedVariance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)


"""

def __init__(
self,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)

self.add_state("y", default=[], dist_reduce_fx=None)
self.add_state("y_pred", default=[], dist_reduce_fx=None)

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
self.y.append(target)
self.y_pred.append(preds)

def compute(self):
"""
Computes explained variance over state.
"""
y_true = torch.cat(self.y, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)

y_diff_avg = torch.mean(y_true - y_pred, dim=0)
numerator = torch.mean((y_true - y_pred - y_diff_avg) ** 2, dim=0)

y_true_avg = torch.mean(y_true, dim=0)
denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0)

# TODO: multioutput
return 1.0 - torch.mean(numerator / denominator)
51 changes: 51 additions & 0 deletions tests/metrics/regression/test_explained_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import pytest
from collections import namedtuple
from functools import partial

from pytorch_lightning.metrics.regression import ExplainedVariance
from sklearn.metrics import explained_variance_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE

torch.manual_seed(42)

num_targets = 5

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)


def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_fn(sk_target, sk_preds)


def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
return sk_fn(sk_target, sk_preds)


@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
@pytest.mark.parametrize(
"preds, target, sk_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
],
)
def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric):
compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp)