diff --git a/CHANGELOG.md b/CHANGELOG.md index 87a70b4012f..5fbaf20fbe2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ) - Added `prefix` argument to `MetricCollection` ([#70](https://github.com/PyTorchLightning/metrics/pull/70)) - Added `__getitem__` as metric arithmetic operation ([#142](https://github.com/PyTorchLightning/metrics/pull/142)) +- Added property `is_differentiable` to metrics and test for differentiability ([#154](https://github.com/PyTorchLightning/metrics/pull/154)) ### Changed diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 8354bd0943f..c505d1aedae 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -39,18 +39,6 @@ Example implementation: def compute(self): return self.correct.float() / self.total -Metrics support backpropagation, if all computations involved in the metric calculation -are differentiable. However, note that the cached state is detached from the computational -graph and cannot be backpropagated. Not doing this would mean storing the computational -graph for each update call, which can lead to out-of-memory errors. -In practise this means that: - -.. code-block:: python - - metric = MyMetric() - val = metric(pred, target) # this value can be backpropagated - val = metric.compute() # this value cannot be backpropagated - Internal implementation details ------------------------------- diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst index 1c946ae66ad..24597bb8ae6 100644 --- a/docs/source/pages/overview.rst +++ b/docs/source/pages/overview.rst @@ -281,3 +281,31 @@ They simply compute the metric value based on the given inputs. Also, the integration within other parts of PyTorch Lightning will never be as tight as with the Module-based interface. If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the Module interface. + + +***************************** +Metrics and differentiability +***************************** + +Metrics support backpropagation, if all computations involved in the metric calculation +are differentiable. All modular metrics have a property that determines if a metric is +differentible or not. + +.. code-block:: python + + @property + def is_differentiable(self) -> bool: + return True/False + +However, note that the cached state is detached from the computational +graph and cannot be backpropagated. Not doing this would mean storing the computational +graph for each update call, which can lead to out-of-memory errors. +In practise this means that: + +.. code-block:: python + + metric = MyMetric() + val = metric(pred, target) # this value can be backpropagated + val = metric.compute() # this value cannot be backpropagated + +A functional metric is differentiable if its corresponding modular metric is differentiable. diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index cc342ec8570..e64303162f1 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -103,6 +103,18 @@ def test_accuracy_fn(self, preds, target, subset_accuracy): }, ) + def test_accuracy_differentiability(self, preds, target, subset_accuracy): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=Accuracy, + metric_functional=accuracy, + metric_args={ + "threshold": THRESHOLD, + "subset_accuracy": subset_accuracy + } + ) + _l1to4 = [0.1, 0.2, 0.3, 0.4] _l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 2fa72fd2df9..d1d406ff849 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -393,7 +393,7 @@ def run_precision_test_cpu( metric_args: dict = {}, **kwargs_update, ): - """Test if an metric can be used with half precision tensors on cpu + """Test if a metric can be used with half precision tensors on cpu Args: preds: torch tensor with predictions target: torch tensor with targets @@ -416,7 +416,7 @@ def run_precision_test_gpu( metric_args: dict = {}, **kwargs_update, ): - """Test if an metric can be used with half precision tensors on gpu + """Test if a metric can be used with half precision tensors on gpu Args: preds: torch tensor with predictions target: torch tensor with targets @@ -430,6 +430,39 @@ def run_precision_test_gpu( metric_module(**metric_args), metric_functional, preds, target, device="cuda", **kwargs_update ) + def run_differentiability_test( + self, + preds: torch.Tensor, + target: torch.Tensor, + metric_module: Metric, + metric_functional: Callable, + metric_args: dict = {}, + ): + """Test if a metric is differentiable or not + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_module: the metric module to test + metric_args: dict with additional arguments used for class initialization + """ + # only floating point tensors can require grad + metric = metric_module(**metric_args) + if preds.is_floating_point(): + preds.requires_grad = True + out = metric(preds[0], target[0]) + assert metric.is_differentiable == out.requires_grad + + if metric.is_differentiable: + # check for numerical correctness + assert torch.autograd.gradcheck( + partial(metric_functional, **metric_args), + (preds[0].double(), target[0]) + ) + + # reset as else it will carry over to other tests + preds.requires_grad = False + class DummyMetric(Metric): name = "Dummy" diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index ac20e0cd2f5..3e1832f99d3 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -96,6 +96,14 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met sk_metric=partial(sk_metric, sk_fn=sk_fn), ) + def test_mean_error_differentiability(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=metric_class, + metric_functional=metric_functional + ) + @pytest.mark.skipif( not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' ) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index e40db2f5619..a3670c45692 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -153,3 +153,7 @@ def compute(self) -> Tensor: Computes accuracy based on inputs passed in to ``update`` previously. """ return _accuracy_compute(self.correct, self.total) + + @property + def is_differentiable(self): + return False diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 6cc88ad9ce4..93c74a43697 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -451,6 +451,10 @@ def __pos__(self): def __getitem__(self, idx): return CompositionalMetric(lambda x: x[idx], self, None) + @property + def is_differentiable(self): + raise NotImplementedError + def _neg(tensor: Tensor): return -torch.abs(tensor) diff --git a/torchmetrics/regression/mean_absolute_error.py b/torchmetrics/regression/mean_absolute_error.py index 470394a52b4..a828c9ce716 100644 --- a/torchmetrics/regression/mean_absolute_error.py +++ b/torchmetrics/regression/mean_absolute_error.py @@ -84,3 +84,7 @@ def compute(self): Computes mean absolute error over state. """ return _mean_absolute_error_compute(self.sum_abs_error, self.total) + + @property + def is_differentiable(self): + return True diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index 4a71639ea35..a1bd8a6a282 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -85,3 +85,7 @@ def compute(self): Computes mean squared error over state. """ return _mean_squared_error_compute(self.sum_squared_error, self.total) + + @property + def is_differentiable(self): + return True diff --git a/torchmetrics/regression/mean_squared_log_error.py b/torchmetrics/regression/mean_squared_log_error.py index 322a7cc770e..9519edfee69 100644 --- a/torchmetrics/regression/mean_squared_log_error.py +++ b/torchmetrics/regression/mean_squared_log_error.py @@ -90,3 +90,7 @@ def compute(self): Compute mean squared logarithmic error over state. """ return _mean_squared_log_error_compute(self.sum_squared_log_error, self.total) + + @property + def is_differentiable(self): + return True