Skip to content

Commit

Permalink
Add automatic optimization property setter to lightning module (#5169)
Browse files Browse the repository at this point in the history
* add automatic optimization property setter to lightning module

* Update test_manual_optimization.py

Co-authored-by: chaton <[email protected]>
(cherry picked from commit 8748293)
  • Loading branch information
ananthsub authored and Borda committed Jan 26, 2021
1 parent 3782a06 commit eee2515
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
8 changes: 7 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, *args, **kwargs):
self._running_manual_backward = False
self._current_hook_fx_name = None
self._current_dataloader_idx = None
self._automatic_optimization: bool = True

def optimizers(self):
opts = self.trainer.optimizers
Expand Down Expand Up @@ -151,7 +152,12 @@ def automatic_optimization(self) -> bool:
"""
If False you are responsible for calling .backward, .step, zero_grad.
"""
return True
return self._automatic_optimization

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization


def print(self, *args, **kwargs) -> None:
r"""
Expand Down
9 changes: 5 additions & 4 deletions tests/trainer/optimization/test_manual_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def test_multiple_optimizers_manual(tmpdir):
Tests that only training_step can be used
"""
class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
Expand Down Expand Up @@ -69,10 +74,6 @@ def configure_optimizers(self):
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2

@property
def automatic_optimization(self) -> bool:
return False

model = TestModel()
model.val_dataloader = None

Expand Down

0 comments on commit eee2515

Please sign in to comment.