Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimizer hooks in callbacks #4379

Merged
merged 4 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,15 @@ def on_save_checkpoint(self, trainer, pl_module):
def on_load_checkpoint(self, checkpointed_state):
"""Called when loading a model checkpoint, use to reload state."""
pass

def on_after_backward(self, trainer, pl_module):
"""
Called after loss.backward() and before optimizers do anything.
"""
pass

def on_before_zero_grad(self, trainer, pl_module, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().
"""
pass
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,17 @@ def on_load_checkpoint(self, checkpoint):
if state:
state = deepcopy(state)
callback.on_load_checkpoint(state)

def on_after_backward(self):
"""
Called after loss.backward() and before optimizers do anything.
"""
for callback in self.callbacks:
callback.on_after_backward(self, self.get_model())

def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().
"""
for callback in self.callbacks:
callback.on_before_zero_grad(self, self.get_model(), optimizer)
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
)

def on_before_zero_grad(self, optimizer):
model = self.trainer.get_model()
model.on_before_zero_grad(optimizer)
self.trainer.call_hook('on_before_zero_grad', optimizer)

def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
Expand Down
18 changes: 18 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(self):
self.on_validation_end_called = False
self.on_test_start_called = False
self.on_test_end_called = False
self.on_after_backward_called = False
self.on_before_zero_grad_called = False

def setup(self, trainer, pl_module, stage: str):
assert isinstance(trainer, Trainer)
Expand Down Expand Up @@ -160,6 +162,14 @@ def on_test_end(self, trainer, pl_module):
_check_args(trainer, pl_module)
self.on_test_end_called = True

def on_after_backward(self, trainer, pl_module):
_check_args(trainer, pl_module)
self.on_after_backward_called = True

def on_before_zero_grad(self, trainer, pl_module, optimizer):
_check_args(trainer, pl_module)
self.on_before_zero_grad_called = True

test_callback = TestCallback()

trainer_options = dict(
Expand Down Expand Up @@ -197,6 +207,8 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_validation_end_called
assert not test_callback.on_test_start_called
assert not test_callback.on_test_end_called
assert not test_callback.on_after_backward_called
assert not test_callback.on_before_zero_grad_called

# fit model
trainer = Trainer(**trainer_options)
Expand Down Expand Up @@ -228,6 +240,8 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_validation_end_called
assert not test_callback.on_test_start_called
assert not test_callback.on_test_end_called
assert not test_callback.on_after_backward_called
assert not test_callback.on_before_zero_grad_called

trainer.fit(model)

Expand Down Expand Up @@ -257,6 +271,8 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_test_batch_end_called
assert not test_callback.on_test_start_called
assert not test_callback.on_test_end_called
assert test_callback.on_after_backward_called
assert test_callback.on_before_zero_grad_called

# reset setup teardown callback
test_callback.teardown_called = False
Expand All @@ -277,3 +293,5 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_validation_end_called
assert not test_callback.on_validation_batch_end_called
assert not test_callback.on_validation_batch_start_called
assert not test_callback.on_after_backward_called
assert not test_callback.on_before_zero_grad_called