From 37ef01a5f88b50379b9ddc15b580215f507e7bab Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 17 Dec 2020 02:52:16 -0800 Subject: [PATCH 1/2] add automatic optimization property setter to lightning module --- pytorch_lightning/core/lightning.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ab66435a2935d..672c09895a66f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -112,6 +112,8 @@ def __init__(self, *args, **kwargs): self._current_hook_fx_name = None self._current_dataloader_idx = None + self._automatic_optimization: bool = True + def optimizers(self): opts = self.trainer.optimizers @@ -160,7 +162,12 @@ def automatic_optimization(self) -> bool: """ If False you are responsible for calling .backward, .step, zero_grad. """ - return True + return self._automatic_optimization + + @automatic_optimization.setter + def automatic_optimization(self, automatic_optimization: bool) -> None: + self._automatic_optimization = automatic_optimization + def print(self, *args, **kwargs) -> None: r""" From 551364901ecf875cb5004925c24d06a0cb2383bc Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 17 Dec 2020 02:55:48 -0800 Subject: [PATCH 2/2] Update test_manual_optimization.py --- tests/trainer/optimization/test_manual_optimization.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 33d14e852b285..05fa6667cfe0a 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -33,6 +33,11 @@ def test_multiple_optimizers_manual(tmpdir): Tests that only training_step can be used """ class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + def training_step(self, batch, batch_idx, optimizer_idx): # manual (opt_a, opt_b) = self.optimizers() @@ -69,10 +74,6 @@ def configure_optimizers(self): optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 - @property - def automatic_optimization(self) -> bool: - return False - model = TestModel() model.val_dataloader = None