Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added changes for Test Differentiability [1/n] #154

Merged
merged 17 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,17 @@ They simply compute the metric value based on the given inputs.
Also, the integration within other parts of PyTorch Lightning will never be as tight as with the Module-based interface.
If you look for just computing the values, the functional metrics are the way to go.
However, if you are looking for the best integration and user experience, please consider also using the Module interface.


*****************************
Metrics and Differentiability
*****************************

Every metric exposes an ``is_differentiable`` property that states whether the metric is differentiable or not.

.. code-block:: python

    @property
    def is_differentiable(self):
        return True  # or False, depending on the metric

SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 8 additions & 0 deletions tests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def test_accuracy_fn(self, preds, target, subset_accuracy):
},
)

def test_metrics_differentiability(self, preds, target):
    """Check that ``Accuracy.is_differentiable`` agrees with autograd behaviour."""
    self.test_differentiability(
        preds=preds,
        target=target,
        metric_module=Accuracy(),
    )


_l1to4 = [0.1, 0.2, 0.3, 0.4]
_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4])
Expand Down
19 changes: 17 additions & 2 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def run_precision_test_cpu(
metric_functional: Callable,
metric_args: dict = {},
):
"""Test if an metric can be used with half precision tensors on cpu
"""Test if a metric can be used with half precision tensors on cpu
Args:
preds: torch tensor with predictions
target: torch tensor with targets
Expand All @@ -388,7 +388,7 @@ def run_precision_test_gpu(
metric_functional: Callable,
metric_args: dict = {},
):
"""Test if an metric can be used with half precision tensors on gpu
"""Test if a metric can be used with half precision tensors on gpu
Args:
preds: torch tensor with predictions
target: torch tensor with targets
Expand All @@ -400,6 +400,21 @@ def run_precision_test_gpu(
metric_module(**metric_args), partial(metric_functional, **metric_args), preds, target, device="cuda"
)

def test_differentiability(
self,
preds: torch.Tensor,
target: torch.Tensor,
metric_module: Metric,
):
"""Test if a metric is differentiable or not
Args:
Borda marked this conversation as resolved.
Show resolved Hide resolved
preds: torch tensor with predictions
target: torch tensor with targets
metric_module: the metric module to test
"""
preds.requires_grad = True
out = metric_module(preds, target)
assert metric_module.is_differentiable == out.requires_grad
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

class DummyMetric(Metric):
name = "Dummy"
Expand Down
8 changes: 8 additions & 0 deletions tests/regression/test_mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met
sk_metric=partial(sk_metric, sk_fn=sk_fn),
)

def test_metrics_differentiability(self, preds, target):
    """Check that ``MeanSquaredError.is_differentiable`` agrees with autograd behaviour."""
    self.test_differentiability(
        preds=preds,
        target=target,
        metric_module=MeanSquaredError(),
    )

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
)
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,7 @@ def compute(self) -> Tensor:
Computes accuracy based on inputs passed in to ``update`` previously.
"""
return _accuracy_compute(self.correct, self.total)

@property
def is_differentiable(self) -> bool:
    # Accuracy is computed from discrete correct/total counts (see
    # ``compute``), which carry no gradient information, so the metric
    # cannot be back-propagated through.
    return False
4 changes: 4 additions & 0 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,10 @@ def __pos__(self):
def __getitem__(self, idx):
return CompositionalMetric(lambda x: x[idx], self, None)

@property
def is_differentiable(self) -> bool:
    # Base class declares the interface only: every concrete metric must
    # override this property to state whether its computed value supports
    # back-propagation through its inputs.
    raise NotImplementedError


def _neg(tensor: Tensor):
return -torch.abs(tensor)
Expand Down
4 changes: 4 additions & 0 deletions torchmetrics/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ def compute(self):
Computes mean squared error over state.
"""
return _mean_squared_error_compute(self.sum_squared_error, self.total)

@property
def is_differentiable(self) -> bool:
    # Mean squared error is a smooth function of the predictions, so
    # gradients flow through ``compute``.
    return True