Update explained variance metric (#4024)
* update metrics

* pep8

* Update pytorch_lightning/metrics/regression/explained_variance.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* add typing for testing utils

* change from assert to raise exception

* add test for raised shape error

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
SkafteNicki and Borda authored Oct 9, 2020

1 parent 1b4209b commit 8a3c800
Showing 9 changed files with 140 additions and 25 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/metrics/metric.py
@@ -234,3 +234,8 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)

def _check_same_shape(self, pred: torch.Tensor, target: torch.Tensor):
""" Check that predictions and target have the same shape, else raise error """
if pred.shape != target.shape:
raise RuntimeError('Predictions and targets are expected to have the same shape')
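The new helper replaces the per-metric `assert` statements with an explicit `RuntimeError`, so the shape check cannot be stripped by running Python with `-O` and the failure is a proper exception users can catch. A minimal standalone sketch of the behavior (the free function below mirrors the `_check_same_shape` method added above; the name and the try/except demo are mine, not library code):

    import torch


    def check_same_shape(pred: torch.Tensor, target: torch.Tensor):
        """Standalone version of the shape check: raise instead of assert."""
        if pred.shape != target.shape:
            raise RuntimeError('Predictions and targets are expected to have the same shape')


    # Mismatched shapes now surface as a RuntimeError rather than an AssertionError
    try:
        check_same_shape(torch.randn(100), torch.randn(50))
    except RuntimeError as err:
        print(err)  # Predictions and targets are expected to have the same shape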
67 changes: 62 additions & 5 deletions pytorch_lightning/metrics/regression/explained_variance.py
@@ -1,13 +1,40 @@
import torch
from typing import Any, Callable, Optional, Union
from typing import Any, Optional

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


class ExplainedVariance(Metric):
"""
Computes explained variance.
Forward accepts
- ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput)
- ``target`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput)
In the case of multioutput, by default the variances will be uniformly
averaged over the additional dimensions. See the `multioutput` argument
for changing this behavior.
Args:
multioutput:
Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is `'uniform_average'`):
* `'raw_values'` returns full set of scores
* `'uniform_average'` scores are uniformly averaged
* `'variance_weighted'` scores are weighted by their individual variances
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
ddp_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import ExplainedVariance
@@ -17,11 +44,16 @@ class ExplainedVariance(Metric):
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
"""

def __init__(
self,
multioutput: str = 'uniform_average',
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -31,10 +63,19 @@ def __init__(
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)

allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
if multioutput not in allowed_multioutput:
raise ValueError(
f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
)
self.multioutput = multioutput
self.add_state("y", default=[], dist_reduce_fx=None)
self.add_state("y_pred", default=[], dist_reduce_fx=None)

rank_zero_warn('Metric `ExplainedVariance` will save all targets and'
' predictions in buffer. For large datasets this may lead'
' to a large memory footprint.')

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
@@ -43,6 +84,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
self._check_same_shape(preds, target)
self.y.append(target)
self.y_pred.append(preds)

@@ -59,5 +101,20 @@ def compute(self):
y_true_avg = torch.mean(y_true, dim=0)
denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0)

# TODO: multioutput
return 1.0 - torch.mean(numerator / denominator)
# Take care of division by zero
nonzero_numerator = numerator != 0
nonzero_denominator = denominator != 0
valid_score = nonzero_numerator & nonzero_denominator
output_scores = torch.ones_like(y_diff_avg)
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score])
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.

# Decide what to do in multioutput case
# Todo: allow user to pass in tensor with weights
if self.multioutput == 'raw_values':
return output_scores
if self.multioutput == 'uniform_average':
return torch.mean(output_scores)
if self.multioutput == 'variance_weighted':
denom_sum = torch.sum(denominator)
return torch.sum(denominator / denom_sum * output_scores)
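Since the new `compute()` logic is split across the hunks above, here is a self-contained sketch in plain torch of the whole computation, including the zero-division guard and the three `multioutput` reductions. The standalone function and its variable names are mine; the metric class itself accumulates `y`/`y_pred` in state buffers before this step.

    import torch


    def explained_variance(preds: torch.Tensor, target: torch.Tensor,
                           multioutput: str = 'uniform_average') -> torch.Tensor:
        # Per-output variance of the residuals (numerator) and of the targets (denominator)
        diff_avg = torch.mean(target - preds, dim=0)
        numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)
        target_avg = torch.mean(target, dim=0)
        denominator = torch.mean((target - target_avg) ** 2, dim=0)

        # Zero-division guard: constant targets predicted exactly score 1, otherwise 0
        nonzero_numerator = numerator != 0
        nonzero_denominator = denominator != 0
        valid = nonzero_numerator & nonzero_denominator
        scores = torch.ones_like(diff_avg)
        scores[valid] = 1.0 - numerator[valid] / denominator[valid]
        scores[nonzero_numerator & ~nonzero_denominator] = 0.0

        if multioutput == 'raw_values':
            return scores
        if multioutput == 'uniform_average':
            return torch.mean(scores)
        if multioutput == 'variance_weighted':
            return torch.sum(denominator / torch.sum(denominator) * scores)
        raise ValueError(f'Invalid input to argument `multioutput`: {multioutput}')


    # Same data as the docstring example above
    target = torch.tensor([[0.5, 1.0], [-1.0, 1.0], [7.0, -6.0]])
    preds = torch.tensor([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])
    print(explained_variance(preds, target, multioutput='raw_values'))  # ~ tensor([0.9677, 1.0000])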
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/regression/mean_absolute_error.py
@@ -50,8 +50,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape, \
'Predictions and targets are expected to have the same shape'
self._check_same_shape(preds, target)
abs_error = torch.abs(preds - target)

self.sum_abs_error += torch.sum(abs_error)
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/regression/mean_squared_error.py
@@ -51,8 +51,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape, \
'Predictions and targets are expected to have the same shape'
self._check_same_shape(preds, target)
squared_error = torch.pow(preds - target, 2)

self.sum_squared_error += torch.sum(squared_error)
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/regression/mean_squared_log_error.py
@@ -51,8 +51,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape, \
'Predictions and targets are expected to have the same shape'
self._check_same_shape(preds, target)
squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2)

self.sum_squared_log_error += torch.sum(squared_log_error)
4 changes: 1 addition & 3 deletions tests/metrics/classification/test_accuracy.py
@@ -1,15 +1,13 @@
import os
import pytest
import torch
import os
import numpy as np
from collections import namedtuple

from pytorch_lightning.metrics.classification.accuracy import Accuracy
from sklearn.metrics import accuracy_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE

torch.manual_seed(42)

23 changes: 19 additions & 4 deletions tests/metrics/regression/test_explained_variance.py
@@ -6,8 +6,7 @@
from pytorch_lightning.metrics.regression import ExplainedVariance
from sklearn.metrics import explained_variance_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE

torch.manual_seed(42)

@@ -40,12 +39,28 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted'])
@pytest.mark.parametrize(
"preds, target, sk_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
],
)
def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric):
compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp)
def test_explained_variance(ddp, ddp_sync_on_step, multioutput, preds, target, sk_metric):
compute_batch(
preds,
target,
ExplainedVariance,
partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)),
ddp_sync_on_step,
ddp,
metric_args=dict(multioutput=multioutput),
)


def test_error_on_different_shape(metric_class=ExplainedVariance):
metric = metric_class()
with pytest.raises(RuntimeError,
match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
11 changes: 9 additions & 2 deletions tests/metrics/regression/test_mean_error.py
@@ -6,8 +6,7 @@
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_squared_log_error

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE

torch.manual_seed(42)

@@ -57,3 +56,11 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
)
def test_mean_error(ddp, ddp_sync_on_step, preds, target, sk_metric, metric_class, sk_fn):
compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), ddp_sync_on_step, ddp)


@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError])
def test_error_on_different_shape(metric_class):
metric = metric_class()
with pytest.raises(RuntimeError,
match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
46 changes: 41 additions & 5 deletions tests/metrics/utils.py
@@ -1,23 +1,39 @@
import torch
import numpy as np
import os
import sys
import pytest
import pickle
from typing import Callable

import torch
import numpy as np

from pytorch_lightning.metrics import Metric

NUM_PROCESSES = 2
NUM_BATCHES = 10
BATCH_SIZE = 16


def setup_ddp(rank, world_size):
""" Setup ddp enviroment """
os.environ["MASTER_ADDR"] = 'localhost'
os.environ['MASTER_PORT'] = '8088'
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args={}):

def _compute_batch(rank: int,
preds: torch.Tensor,
target: torch.Tensor,
metric_class: Metric,
sk_metric: Callable,
ddp_sync_on_step: bool,
worldsize: int = 1,
metric_args: dict = {}
):
""" Utility function doing the actual comparison between lightning metric
and reference metric
"""
# Instantiate lightning metric
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)

# verify metrics work after being loaded from pickled state
@@ -52,7 +68,27 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste
assert np.allclose(result.numpy(), sk_result)


def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}):
def compute_batch(preds: torch.Tensor,
target: torch.Tensor,
metric_class: Metric,
sk_metric: Callable,
ddp_sync_on_step: bool,
ddp: bool = False,
metric_args: dict = {}
):
""" Utility function for comparing the result between a lightning class
metric and another metric (often sklearns)
Args:
preds: prediction tensor
target: target tensor
metric_class: lightning metric class to test
sk_metric: function to compare with
ddp_sync_on_step: bool, determine if values should be reduce on step
ddp: bool, determine if test should run in ddp mode
metric_args: dict, additional kwargs that are use when instanciating
the lightning metric
"""
if ddp:
if sys.platform == "win32":
pytest.skip("DDP not supported on windows")
