diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index e8351072d2cc0..5fb9f7243b201 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -229,174 +229,171 @@ def train_dataloader(self): trainer.fit(model) -def test_trainer_model_hook_system(tmpdir): - """Test the LightningModule hook system.""" +class HookedModel(BoringModel): - class HookedModel(BoringModel): + def __init__(self): + super().__init__() + self.called = [] - def __init__(self): - super().__init__() - self.called = [] + def on_after_backward(self): + self.called.append("on_after_backward") + super().on_after_backward() - def on_after_backward(self): - self.called.append("on_after_backward") - super().on_after_backward() + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) - def on_before_zero_grad(self, *args, **kwargs): - self.called.append("on_before_zero_grad") - super().on_before_zero_grad(*args, **kwargs) + def on_epoch_start(self): + self.called.append("on_epoch_start") + super().on_epoch_start() - def on_epoch_start(self): - self.called.append("on_epoch_start") - super().on_epoch_start() + def on_epoch_end(self): + self.called.append("on_epoch_end") + super().on_epoch_end() - def on_epoch_end(self): - self.called.append("on_epoch_end") - super().on_epoch_end() + def on_fit_start(self): + self.called.append("on_fit_start") + super().on_fit_start() - def on_fit_start(self): - self.called.append("on_fit_start") - super().on_fit_start() + def on_fit_end(self): + self.called.append("on_fit_end") + super().on_fit_end() - def on_fit_end(self): - self.called.append("on_fit_end") - super().on_fit_end() + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) - def on_hpc_load(self, *args, **kwargs): - self.called.append("on_hpc_load") - super().on_hpc_load(*args, **kwargs) + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) - def on_hpc_save(self, *args, **kwargs): - self.called.append("on_hpc_save") - super().on_hpc_save(*args, **kwargs) + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) - def on_load_checkpoint(self, *args, **kwargs): - self.called.append("on_load_checkpoint") - super().on_load_checkpoint(*args, **kwargs) + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) - def on_save_checkpoint(self, *args, **kwargs): - self.called.append("on_save_checkpoint") - super().on_save_checkpoint(*args, **kwargs) + def on_pretrain_routine_start(self): + self.called.append("on_pretrain_routine_start") + super().on_pretrain_routine_start() - def on_pretrain_routine_start(self): - self.called.append("on_pretrain_routine_start") - super().on_pretrain_routine_start() + def on_pretrain_routine_end(self): + self.called.append("on_pretrain_routine_end") + super().on_pretrain_routine_end() - def on_pretrain_routine_end(self): - self.called.append("on_pretrain_routine_end") - super().on_pretrain_routine_end() + def on_train_start(self): + self.called.append("on_train_start") + super().on_train_start() - def on_train_start(self): - self.called.append("on_train_start") - super().on_train_start() + def on_train_end(self): + self.called.append("on_train_end") + super().on_train_end() - def on_train_end(self): - self.called.append("on_train_end") - super().on_train_end() + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) - def on_before_batch_transfer(self, *args, **kwargs): - self.called.append("on_before_batch_transfer") - return super().on_before_batch_transfer(*args, **kwargs) + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) - def transfer_batch_to_device(self, *args, **kwargs): - self.called.append("transfer_batch_to_device") - return super().transfer_batch_to_device(*args, **kwargs) + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) - def on_after_batch_transfer(self, *args, **kwargs): - self.called.append("on_after_batch_transfer") - return super().on_after_batch_transfer(*args, **kwargs) + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) - def on_train_batch_start(self, *args, **kwargs): - self.called.append("on_train_batch_start") - super().on_train_batch_start(*args, **kwargs) + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) - def on_train_batch_end(self, *args, **kwargs): - self.called.append("on_train_batch_end") - super().on_train_batch_end(*args, **kwargs) + def on_train_epoch_start(self): + self.called.append("on_train_epoch_start") + super().on_train_epoch_start() - def on_train_epoch_start(self): - self.called.append("on_train_epoch_start") - super().on_train_epoch_start() + def on_train_epoch_end(self): + self.called.append("on_train_epoch_end") + super().on_train_epoch_end() - def on_train_epoch_end(self): - self.called.append("on_train_epoch_end") - super().on_train_epoch_end() + def on_validation_start(self): + self.called.append("on_validation_start") + super().on_validation_start() - def on_validation_start(self): - self.called.append("on_validation_start") - super().on_validation_start() + def on_validation_end(self): + self.called.append("on_validation_end") + super().on_validation_end() - def on_validation_end(self): - self.called.append("on_validation_end") - super().on_validation_end() + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) - def on_validation_batch_start(self, *args, **kwargs): - self.called.append("on_validation_batch_start") - super().on_validation_batch_start(*args, **kwargs) + def on_validation_batch_end(self, *args, **kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) - def on_validation_batch_end(self, *args, **kwargs): - self.called.append("on_validation_batch_end") - super().on_validation_batch_end(*args, **kwargs) + def on_validation_epoch_start(self): + self.called.append("on_validation_epoch_start") + super().on_validation_epoch_start() - def on_validation_epoch_start(self): - self.called.append("on_validation_epoch_start") - super().on_validation_epoch_start() + def on_validation_epoch_end(self, *args, **kwargs): + self.called.append("on_validation_epoch_end") + super().on_validation_epoch_end(*args, **kwargs) - def on_validation_epoch_end(self, *args, **kwargs): - self.called.append("on_validation_epoch_end") - super().on_validation_epoch_end(*args, **kwargs) + def on_test_start(self): + self.called.append("on_test_start") + super().on_test_start() - def on_test_start(self): - self.called.append("on_test_start") - super().on_test_start() + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) - def on_test_batch_start(self, *args, **kwargs): - self.called.append("on_test_batch_start") - super().on_test_batch_start(*args, **kwargs) + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) - def on_test_batch_end(self, *args, **kwargs): - self.called.append("on_test_batch_end") - super().on_test_batch_end(*args, **kwargs) + def on_test_epoch_start(self): + self.called.append("on_test_epoch_start") + super().on_test_epoch_start() - def on_test_epoch_start(self): - self.called.append("on_test_epoch_start") - super().on_test_epoch_start() + def on_test_epoch_end(self, *args, **kwargs): + self.called.append("on_test_epoch_end") + super().on_test_epoch_end(*args, **kwargs) - def on_test_epoch_end(self, *args, **kwargs): - self.called.append("on_test_epoch_end") - super().on_test_epoch_end(*args, **kwargs) + def on_validation_model_eval(self): + self.called.append("on_validation_model_eval") + super().on_validation_model_eval() - def on_validation_model_eval(self): - self.called.append("on_validation_model_eval") - super().on_validation_model_eval() + def on_validation_model_train(self): + self.called.append("on_validation_model_train") + super().on_validation_model_train() - def on_validation_model_train(self): - self.called.append("on_validation_model_train") - super().on_validation_model_train() + def on_test_model_eval(self): + self.called.append("on_test_model_eval") + super().on_test_model_eval() - def on_test_model_eval(self): - self.called.append("on_test_model_eval") - super().on_test_model_eval() + def on_test_model_train(self): + self.called.append("on_test_model_train") + super().on_test_model_train() - def on_test_model_train(self): - self.called.append("on_test_model_train") - super().on_test_model_train() + def on_test_end(self): + self.called.append("on_test_end") + super().on_test_end() - def on_test_end(self): - self.called.append("on_test_end") - super().on_test_end() + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) - def setup(self, stage=None): - self.called.append(f"setup_{stage}") - super().setup(stage=stage) + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage) - def teardown(self, stage=None): - self.called.append(f"teardown_{stage}") - super().teardown(stage) +def test_trainer_model_hook_system_fit(tmpdir): model = HookedModel() - - # fit model trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -406,9 +403,7 @@ def teardown(self, stage=None): progress_bar_refresh_rate=0, weights_summary=None, ) - assert model.called == [] - trainer.fit(model) expected = [ 'setup_fit', @@ -467,8 +462,19 @@ def teardown(self, stage=None): ] assert model.called == expected - model = HookedModel() +def test_trainer_model_hook_system_validate(tmpdir): + model = HookedModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + assert model.called == [] trainer.validate(model, verbose=False) expected = [ 'setup_validate', @@ -489,9 +495,20 @@ def teardown(self, stage=None): ] assert model.called == expected + +def test_trainer_model_hook_system_test(tmpdir): model = HookedModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + assert model.called == [] trainer.test(model, verbose=False) - expected = [ 'setup_test', 'on_test_model_eval', @@ -629,11 +646,24 @@ def on_after_batch_transfer(self, *args, **kwargs): trainer.fit(model, datamodule=dm) expected = [ - 'prepare_data', 'setup_fit', 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'train_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', - 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer', - 'teardown_fit' + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit', ] assert dm.called == expected @@ -641,8 +671,13 @@ def on_after_batch_transfer(self, *args, **kwargs): trainer.validate(model, datamodule=dm, verbose=False) expected = [ - 'prepare_data', 'setup_validate', 'val_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'teardown_validate' + 'prepare_data', + 'setup_validate', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_validate', ] assert dm.called == expected @@ -650,7 +685,12 @@ def on_after_batch_transfer(self, *args, **kwargs): trainer.test(model, datamodule=dm, verbose=False) expected = [ - 'prepare_data', 'setup_test', 'test_dataloader', 'on_before_batch_transfer', 'transfer_batch_to_device', - 'on_after_batch_transfer', 'teardown_test' + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test', ] assert dm.called == expected