From d85afc7b5719c14887458f1bec44e06ea1765b7c Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Nov 2023 17:01:07 +0100
Subject: [PATCH] Docs and tests for how to save and load metrics (#2237)

(cherry picked from commit f45945e81c236bfb946faa13c5192fca8bb09e1f)
---
 docs/source/pages/overview.rst               | 43 +++++++++++++++++
 tests/unittests/bases/test_saving_loading.py | 49 ++++++++++++++++++++
 2 files changed, 92 insertions(+)
 create mode 100644 tests/unittests/bases/test_saving_loading.py

diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index 9659bde0b92..bbe33c69cca 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -182,6 +182,49 @@ In general we have a few recommendations for memory management:
 
 See :ref:`Metric kwargs` for different advanced settings for controlling the memory footprint of metrics.
 
+**************************
+Saving and loading metrics
+**************************
+
+Because metrics are essentially just subclasses of :class:`torch.nn.Module`, saving and loading them works in the
+same way as for any other `nn.Module`, with one key difference (explained below). As with any other `nn.Module`, it
+is recommended to save the state dict rather than the metric object itself, e.g.:
+
+.. code-block:: python
+
+    # Instead of this
+    torch.save(metric, "metric.pt")
+    # do this
+    torch.save(metric.state_dict(), "metric.pt")
+
+The key difference is that metric states are not automatically part of the state dict. This keeps torchmetrics
+backward compatible with checkpoints of models that were created before the metrics were added. This behavior can
+be overridden with the `metric.persistent` method, which marks all metric states to be saved whenever `.state_dict`
+is called. Alternatively, for custom metrics, you can set the `persistent` argument when initializing each state in
+the `self.add_state` method.
+
+A correct example of saving and loading a metric therefore looks like this:
+
+.. code-block:: python
+
+    import torch
+    from torchmetrics.classification import MulticlassAccuracy
+
+    metric = MulticlassAccuracy(num_classes=5).to("cuda")
+    metric.persistent(True)
+    metric.update(torch.randint(5, (100,)).cuda(), torch.randint(5, (100,)).cuda())
+    torch.save(metric.state_dict(), "metric.pth")
+
+    metric2 = MulticlassAccuracy(num_classes=5).to("cpu")
+    metric2.load_state_dict(torch.load("metric.pth", map_location="cpu"))
+
+    # These will match, but be on different devices
+    print(metric.metric_state)
+    print(metric2.metric_state)
+
+In this example, the `map_location` argument accounts for the metric state being saved on one device ("cuda") while
+the metric it is loaded into lives on another ("cpu").
+
 ***********************************************
 Metrics in Distributed Data Parallel (DDP) mode
 ***********************************************
diff --git a/tests/unittests/bases/test_saving_loading.py b/tests/unittests/bases/test_saving_loading.py
new file mode 100644
index 00000000000..f808674367c
--- /dev/null
+++ b/tests/unittests/bases/test_saving_loading.py
@@ -0,0 +1,49 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+from torchmetrics.classification import MulticlassAccuracy
+
+
+@pytest.mark.parametrize("persistent", [True, False])
+@pytest.mark.parametrize("in_device", ["cpu", "cuda"])
+@pytest.mark.parametrize("out_device", ["cpu", "cuda"])
+def test_saving_loading(persistent, in_device, out_device, tmp_path):
+    """Test that saving and loading a metric state dict works as expected."""
+    if (in_device == "cuda" or out_device == "cuda") and not torch.cuda.is_available():
+        pytest.skip("Test requires cuda, but GPU not available.")
+
+    # save to a temporary directory so the test does not leave files behind
+    filepath = tmp_path / "metric.pth"
+
+    metric1 = MulticlassAccuracy(num_classes=5).to(in_device)
+    metric1.persistent(persistent)
+    metric1.update(torch.randint(5, (100,)).to(in_device), torch.randint(5, (100,)).to(in_device))
+    torch.save(metric1.state_dict(), filepath)
+
+    metric2 = MulticlassAccuracy(num_classes=5).to(out_device)
+    metric2.load_state_dict(torch.load(filepath, map_location=out_device))
+
+    metric_state1 = metric1.metric_state
+    metric_state2 = metric2.metric_state
+
+    for key, state1 in metric_state1.items():
+        # move to a common device before comparing; `map_location` already placed state2 on `out_device`
+        state2 = metric_state2[key].to(state1.device)
+        if persistent:
+            # persistent states are part of the state dict, so they survive the save/load round trip
+            assert torch.allclose(state1, state2)
+        else:
+            # non-persistent states are skipped on save, so the loaded metric keeps its default values
+            assert not torch.allclose(state1, state2)
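+
+
+# Illustrative sketch of the alternative mentioned in the docs above: for custom metrics, the
+# `persistent` argument of `self.add_state` controls whether a state is included in the state dict.
+# `DummySum` and `test_add_state_persistent_flag` are placeholder names, not part of the public API.
+@pytest.mark.parametrize("persistent", [True, False])
+def test_add_state_persistent_flag(persistent):
+    """Test that the `persistent` flag of `add_state` controls what is included in the state dict."""
+    from torchmetrics import Metric
+
+    class DummySum(Metric):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            # `persistent` decides whether this state is saved when `.state_dict()` is called
+            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum", persistent=persistent)
+
+        def update(self, x):
+            self.total += x.sum()
+
+        def compute(self):
+            return self.total
+
+    metric = DummySum()
+    metric.update(torch.ones(10))
+    # states registered with `persistent=True` appear in the state dict; others are skipped
+    assert ("total" in metric.state_dict()) == persistent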