From 11aa1a36572a3569f093d9280e0d0c7a6f91fdbf Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 29 Oct 2020 08:54:35 -0700 Subject: [PATCH 1/5] Update metric.py --- pytorch_lightning/metrics/metric.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index f003e0d3da72a..b716817427230 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -145,6 +145,7 @@ def add_state( self._defaults[name] = deepcopy(default) self._reductions[name] = dist_reduce_fx + @torch.jit.unused def forward(self, *args, **kwargs): """ Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. From ac00fcc733d365c82774e2cd140f2f79e8c9a4dd Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 2 Nov 2020 22:48:15 -0800 Subject: [PATCH 2/5] add test --- tests/metrics/test_metric_lightning.py | 48 +++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 7a860ea6c16fd..2f6f76cfbc670 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,5 +1,6 @@ -import torch +import os +import torch from pytorch_lightning import Trainer from pytorch_lightning.metrics import Metric from tests.base.boring_model import BoringModel @@ -78,3 +79,48 @@ def training_step(self, batch, batch_idx): logged = trainer.logged_metrics assert torch.allclose(torch.tensor(logged["sum"]), model.sum) + + +def test_scriptable(tmpdir): + class TestModel(BoringModel): + def __init__(self): + super().__init__() + # the metric is not used in the module's `forward` + # so the module should be exportable to TorchScript + self.metric = SumMetric() + self.sum = 0.0 + + def training_step(self, batch, batch_idx): + x = batch + self.metric(x.sum()) + self.sum += x.sum() + self.log("sum", self.metric, on_epoch=True, on_step=False) + return self.step(x) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + rand_input = torch.randn(10, 32) + + script_model = model.to_torchscript() + + # test that we can still do inference + output = model(rand_input) + script_output = script_model(rand_input) + assert torch.allclose(output, script_output) + + # save to file and re-load to ensure export + load still works for inference + path = os.path.join(tmpdir, "tmp_script.pt") + torch.jit.save(script_model, path) + load_script_model = torch.jit.load(path) + load_script_model_output = load_script_model(rand_input) + assert torch.allclose(output, load_script_model_output) From 76de243fd6d884d7149804790b8dae596b169ffd Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 2 Nov 2020 22:50:39 -0800 Subject: [PATCH 3/5] Update CHANGELOG.md --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc430356191c3..169760cd8f04f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340)) -- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) +- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) ### Changed @@ -45,6 +45,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) +- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428)) + ## [1.0.4] - 2020-10-27 ### Added From d8ff339af4d5a10a7f3f5a4cfcf98b2767e25de2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 2 Nov 2020 23:10:07 -0800 Subject: [PATCH 4/5] Update test_metric_lightning.py --- tests/metrics/test_metric_lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 2f6f76cfbc670..496a8b2125ba2 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -120,7 +120,7 @@ def training_step(self, batch, batch_idx): # save to file and re-load to ensure export + load still works for inference path = os.path.join(tmpdir, "tmp_script.pt") - torch.jit.save(script_model, path) + script_model.save(path) load_script_model = torch.jit.load(path) load_script_model_output = load_script_model(rand_input) assert torch.allclose(output, load_script_model_output) From 54e2f2640b8ca7e2f4d5df878a5cee100274325d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 2 Nov 2020 23:45:13 -0800 Subject: [PATCH 5/5] Update test_metric_lightning.py --- tests/metrics/test_metric_lightning.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 496a8b2125ba2..3c6938734be10 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -117,10 +117,3 @@ def training_step(self, batch, batch_idx): output = model(rand_input) script_output = script_model(rand_input) assert torch.allclose(output, script_output) - - # save to file and re-load to ensure export + load still works for inference - path = os.path.join(tmpdir, "tmp_script.pt") - script_model.save(path) - load_script_model = torch.jit.load(path) - load_script_model_output = load_script_model(rand_input) - assert torch.allclose(output, load_script_model_output)