From 63f3cb7d3573a5973ba04a610a25ae65cb7b2d4b Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 8 Oct 2020 19:39:04 -0400 Subject: [PATCH 1/5] metric fix, explained variance --- docs/source/metrics.rst | 7 ++ pytorch_lightning/metrics/__init__.py | 2 +- pytorch_lightning/metrics/metric.py | 8 +- .../metrics/regression/__init__.py | 1 + .../metrics/regression/explained_variance.py | 74 +++++++++++++++++++ .../regression/test_explained_variance.py | 50 +++++++++++++ 6 files changed, 139 insertions(+), 3 deletions(-) create mode 100644 pytorch_lightning/metrics/regression/explained_variance.py create mode 100644 tests/metrics/regression/test_explained_variance.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index ab6af3b8f4616..30990a881a1e8 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -181,6 +181,13 @@ MeanSquaredLogError :noindex: +ExplainedVariance +^^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance + :noindex: + + Functional Metrics ================== diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 615a08e27a8e8..3833e7fa6cec9 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,4 +1,4 @@ from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError +from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, ExplainedVariance diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index acd2b2d5e2ef7..11733840d14f4 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -96,7 +96,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call the format discussed in the above note. """ - if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0): + if ( + not isinstance(default, torch.Tensor) + and not isinstance(default, list) + and not (isinstance(default, list) and len(default) != 0) + ): raise ValueError( "state variable must be a tensor or any empty list (where you can append tensors)" ) @@ -163,7 +167,7 @@ def _sync_dist(self): elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) - assert isinstance(reduction_fn, (Callable, None)) + assert isinstance(reduction_fn, (Callable)) or reduction_fn is None reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index c5f235aeff12b..a7893e9c26a13 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -1,3 +1,4 @@ from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError +from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py new file mode 100644 index 0000000000000..e7610f3b53d35 --- /dev/null +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -0,0 +1,74 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + +class ExplainedVariance(Metric): + """ + Computes explained variance. + + Example: + + >>> from pytorch_lightning.metrics import ExplainedVariance + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> explained_variance = ExplainedVariance() + >>> explained_variance(preds, target) + tensor(0.9572) + + + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + + self.add_state("y", default=[], dist_reduce_fx=None) + self.add_state("y_pred", default=[], dist_reduce_fx=None) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + self.y.append(target) + self.y_pred.append(preds) + + def compute(self): + """ + Computes explained variance over state. + """ + y_true = torch.cat(self.y, dim=0) + y_pred = torch.cat(self.y_pred, dim=0) + + y_diff_avg = torch.mean(y_true - y_pred, dim=0) + numerator = torch.mean((y_true - y_pred - y_diff_avg) ** 2, dim=0) + + y_true_avg = torch.mean(y_true, dim=0) + denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0) + + # TODO: multioutput + return 1.0 - torch.mean(numerator / denominator) + + + + + + + +# target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) +# preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) +# explained_variance = ExplainedVariance() +# print(explained_variance(preds, target)) diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py new file mode 100644 index 0000000000000..8a1c8606eee4b --- /dev/null +++ b/tests/metrics/regression/test_explained_variance.py @@ -0,0 +1,50 @@ +import torch +import pytest +from collections import namedtuple +from functools import partial + +from pytorch_lightning.metrics.regression import ExplainedVariance +from sklearn.metrics import explained_variance_score + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(sk_target, sk_preds) + + +def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return sk_fn(sk_target, sk_preds) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + #(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), + ], +) +def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric): + compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp) From 4252b9f4a911edae401c6daa0e0061e03d7f08cc Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 8 Oct 2020 19:41:16 -0400 Subject: [PATCH 2/5] one more test --- tests/metrics/regression/test_explained_variance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py index 8a1c8606eee4b..40b6d2840c752 100644 --- a/tests/metrics/regression/test_explained_variance.py +++ b/tests/metrics/regression/test_explained_variance.py @@ -42,7 +42,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): @pytest.mark.parametrize( "preds, target, sk_metric", [ - #(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), ], ) From b66748fcced9361545aee40e01da4708f828f2c9 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 8 Oct 2020 19:43:17 -0400 Subject: [PATCH 3/5] pep8 --- pytorch_lightning/metrics/__init__.py | 7 ++++++- pytorch_lightning/metrics/metric.py | 4 ++-- pytorch_lightning/metrics/regression/explained_variance.py | 5 ----- tests/metrics/regression/test_explained_variance.py | 3 ++- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 3833e7fa6cec9..6a20c6a0b1771 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,4 +1,9 @@ from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, ExplainedVariance +from pytorch_lightning.metrics.regression import ( + MeanSquaredError, + MeanAbsoluteError, + MeanSquaredLogError, + ExplainedVariance, +) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 11733840d14f4..7a9279eb59371 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -98,8 +98,8 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call """ if ( not isinstance(default, torch.Tensor) - and not isinstance(default, list) - and not (isinstance(default, list) and len(default) != 0) + and not isinstance(default, list) # noqa: W503 + and not (isinstance(default, list) and len(default) != 0) # noqa: W503 ): raise ValueError( "state variable must be a tensor or any empty list (where you can append tensors)" diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index e7610f3b53d35..81680802de1c3 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -63,11 +63,6 @@ def compute(self): return 1.0 - torch.mean(numerator / denominator) - - - - - # target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) # preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) # explained_variance = ExplainedVariance() diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py index 40b6d2840c752..7c54d486efeff 100644 --- a/tests/metrics/regression/test_explained_variance.py +++ b/tests/metrics/regression/test_explained_variance.py @@ -7,7 +7,8 @@ from sklearn.metrics import explained_variance_score from tests.metrics.utils import compute_batch, setup_ddp -from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE + torch.manual_seed(42) num_targets = 5 From 52461c7cba90ca645e6459d4efedbd389a50158e Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Thu, 8 Oct 2020 19:46:12 -0400 Subject: [PATCH 4/5] remove comment --- pytorch_lightning/metrics/regression/explained_variance.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 81680802de1c3..e12ab8d8a5d4f 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -61,9 +61,3 @@ def compute(self): # TODO: multioutput return 1.0 - torch.mean(numerator / denominator) - - -# target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) -# preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) -# explained_variance = ExplainedVariance() -# print(explained_variance(preds, target)) From e479ca02a0856976eb2cbe0041e2a6ab9fabcc6f Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Thu, 8 Oct 2020 20:41:58 -0400 Subject: [PATCH 5/5] fix add_state condition --- pytorch_lightning/metrics/metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 7a9279eb59371..34c8ef88a98bb 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -98,8 +98,8 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call """ if ( not isinstance(default, torch.Tensor) - and not isinstance(default, list) # noqa: W503 - and not (isinstance(default, list) and len(default) != 0) # noqa: W503 + and not isinstance(default, list) # noqa: W503 + or (isinstance(default, list) and len(default) != 0) # noqa: W503 ): raise ValueError( "state variable must be a tensor or any empty list (where you can append tensors)"