Make optimizers skippable when using amp (#7975)
Co-authored-by: Yifu Wang <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
3 people authored Jun 16, 2021
1 parent 0004216 commit b71aa55
Showing 3 changed files with 52 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -300,6 +300,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))


- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))


## [1.3.2] - 2021-05-18

### Changed
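For context, a minimal plain-PyTorch sketch of the failure the changelog entry above describes. It is not part of this commit: the parameter and optimizer names are made up for illustration, and it assumes a CUDA device plus the torch.cuda.amp.GradScaler of the PyTorch 1.8/1.9 era that this plugin targets (on CPU the scaler silently disables itself and no error is raised). If an optimizer's closure is skipped so that none of its parameters ever receive gradients, calling scaler.step() for that optimizer trips the scaler's internal inf-check assertion.

import torch

# Two parameter groups with their own optimizers, as in a GAN-style setup.
param_a = torch.nn.Parameter(torch.randn(2, 2, device="cuda"))
param_b = torch.nn.Parameter(torch.randn(2, 2, device="cuda"))
opt_a = torch.optim.SGD([param_a], lr=0.1)
opt_b = torch.optim.SGD([param_b], lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss_a = (param_a ** 2).sum()  # the loss only touches param_a
scaler.scale(loss_a).backward()    # param_b.grad stays None
scaler.step(opt_a)                 # fine: param_a has scaled gradients

# Skipping opt_b's closure but still stepping it, as the old code path did:
scaler.step(opt_b)  # AssertionError: No inf checks were recorded for this optimizer.
scaler.update()     # never reached because of the assertion above
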
14 changes: 8 additions & 6 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -83,19 +83,21 @@ def pre_optimizer_step(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
lambda_closure()

if not pl_module.automatic_optimization:
self.scaler.unscale_(optimizer)
pl_module.trainer.call_hook("on_after_backward")
self.scaler.step(optimizer)
self.scaler.update()
else:
result = lambda_closure()
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
self.scaler.step(optimizer)
self.scaler.update()

return False

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.step(optimizer)
self.scaler.update()

@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
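To make the control flow above concrete, here is a minimal sketch of the guard the automatic-optimization branch now applies: the scaler's step and update are only issued when the closure actually ran backward and returned a result. This is plain PyTorch rather than Lightning code, with made-up names (model, closure, amp_enabled), and it uses enabled=torch.cuda.is_available() so it also runs on CPU, where the scaler is a no-op. The new test below exercises the same behaviour end to end through the Trainer.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_enabled = device == "cuda"

model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

def closure(step: int):
    """Mimics the Lightning closure: returns None when the step is skipped."""
    if step % 2 == 1:
        return None  # skipped step: no forward, no backward
    with torch.cuda.amp.autocast(enabled=amp_enabled):
        loss = model(torch.randn(2, 4, device=device)).sum()
    scaler.scale(loss).backward()
    return loss

for step in range(4):
    optimizer.zero_grad()
    result = closure(step)
    # Mirror of the fixed pre_optimizer_step: only touch the scaler when
    # the closure reports that backward actually ran.
    if result is not None:
        scaler.step(optimizer)
        scaler.update()
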
41 changes: 41 additions & 0 deletions tests/plugins/test_amp_plugins.py
@@ -99,6 +99,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
    trainer.fit(model)


@RunIf(min_gpus=1, amp_native=True)
def test_amp_skip_optimizer(tmpdir):
    """
    Test that optimizers can be skipped when using amp
    """

    class CustomBoringModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(32, 32)
            self.layer2 = torch.nn.Linear(32, 2)

        def forward(self, x: torch.Tensor):
            x = self.layer1(x)
            x = self.layer2(x)
            return x

        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 1:
                return None
            output = self(batch)
            return self.loss(batch, output)

        def configure_optimizers(self):
            return [
                torch.optim.SGD(self.layer1.parameters(), lr=0.1),
                torch.optim.SGD(self.layer2.parameters(), lr=0.1),
            ]

    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=1,
        fast_dev_run=1,
        amp_backend='native',
        precision=16,
    )
    model = CustomBoringModel()
    trainer.fit(model)


@RunIf(min_gpus=2, amp_apex=True, special=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_fit(amp_level, tmpdir):
