Docs and tests for how to save and load metrics (#2237)
(cherry picked from commit f45945e)
SkafteNicki authored and Borda committed Dec 1, 2023
1 parent 9010e44 commit d85afc7
Showing 2 changed files with 92 additions and 0 deletions.
43 changes: 43 additions & 0 deletions docs/source/pages/overview.rst
@@ -182,6 +182,49 @@ In general we have a few recommendations for memory management:

See :ref:`Metric kwargs` for advanced settings that control the memory footprint of metrics.

**************************
Saving and loading metrics
**************************

Because metrics are essentially just a subclass of :class:`torch.nn.Module`, saving and loading metrics works in the
same way as for any other `nn.Module`, with one key difference. As with any `nn.Module`, it is recommended to save the
state dict instead of the actual metric, e.g.:

.. code-block:: python

    # Instead of this
    torch.save(metric, "metric.pt")
    # do this
    torch.save(metric.state_dict(), "metric.pt")

The key difference is that metric states are not part of the state dict by default. This ensures that torchmetrics
stays backward compatible with models that did not use the specific metrics when they were created. This behavior can
be overridden with the `metric.persistent` method, which marks all metric states to also be saved when `.state_dict`
is called. Alternatively, for custom metrics, you can set the `persistent` argument when initializing each state with
the `self.add_state` method.
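
For example, a custom metric can mark its states as persistent when they are created. The following is a minimal
sketch; the `RunningSum` metric and its update logic are purely illustrative and not part of torchmetrics:

.. code-block:: python

    import torch
    from torchmetrics import Metric


    class RunningSum(Metric):
        """Toy metric that accumulates a running sum of its inputs."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # persistent=True marks this state to be included in `.state_dict()`
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum", persistent=True)

        def update(self, x: torch.Tensor) -> None:
            self.total += x.sum()

        def compute(self) -> torch.Tensor:
            return self.total


    metric = RunningSum()
    metric.update(torch.tensor([1.0, 2.0]))
    print(metric.state_dict().keys())  # "total" appears among the saved entries

With the default `persistent=False`, the `total` state would be registered as a non-persistent buffer and silently
skipped when `.state_dict` is called.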

Therefore, a correct example of saving and loading a metric would be:

.. code-block:: python

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    metric = MulticlassAccuracy(num_classes=5).to("cuda")
    metric.persistent(True)
    metric.update(torch.randint(5, (100,)).cuda(), torch.randint(5, (100,)).cuda())
    torch.save(metric.state_dict(), "metric.pth")

    metric2 = MulticlassAccuracy(num_classes=5).to("cpu")
    metric2.load_state_dict(torch.load("metric.pth", map_location="cpu"))

    # These will match, but be on different devices
    print(metric.metric_state)
    print(metric2.metric_state)

In the example, the `map_location` argument accounts for the metric state being saved on a different device (GPU) than
the one the metric is loaded onto (CPU).
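
For instance, a state dict written on a GPU machine can be restored on a CPU-only machine. A minimal sketch, assuming
`metric.pth` was produced by the example above:

.. code-block:: python

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    # Without map_location, torch.load would try to restore the tensors to the
    # device they were saved from (here CUDA) and fail on a CPU-only machine.
    metric = MulticlassAccuracy(num_classes=5)
    metric.load_state_dict(torch.load("metric.pth", map_location="cpu"))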

***********************************************
Metrics in Distributed Data Parallel (DDP) mode
***********************************************
49 changes: 49 additions & 0 deletions tests/unittests/bases/test_saving_loading.py
@@ -0,0 +1,49 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torchmetrics.classification import MulticlassAccuracy


@pytest.mark.parametrize("persistent", [True, False])
@pytest.mark.parametrize("in_device", ["cpu", "cuda"])
@pytest.mark.parametrize("out_device", ["cpu", "cuda"])
def test_saving_loading(persistent, in_device, out_device, tmp_path):
    """Test that saving and loading a metric state dict works as expected."""
    if (in_device == "cuda" or out_device == "cuda") and not torch.cuda.is_available():
        pytest.skip("Test requires cuda, but GPU not available.")

    metric1 = MulticlassAccuracy(num_classes=5).to(in_device)
    metric1.persistent(persistent)
    metric1.update(torch.randint(5, (100,)).to(in_device), torch.randint(5, (100,)).to(in_device))
    torch.save(metric1.state_dict(), tmp_path / "metric.pth")

    metric2 = MulticlassAccuracy(num_classes=5).to(out_device)
    metric2.load_state_dict(torch.load(tmp_path / "metric.pth", map_location=out_device))

    metric_state1 = metric1.metric_state
    metric_state2 = metric2.metric_state

    # Persistent states are included in the state dict, so the loaded metric should
    # match the original; non-persistent states are skipped on save, so the loaded
    # metric keeps its default (reset) values.
    for k, v in metric_state1.items():
        v2 = metric_state2[k].to(v.device)
        if persistent:
            assert torch.allclose(v, v2)
        else:
            assert not torch.allclose(v, v2)
