
Remove memory-retaining epoch-end hooks #16520

Merged
merged 15 commits on Feb 6, 2023

Conversation


@carmocca carmocca commented Jan 26, 2023

Migration guide

training_epoch_end -> on_train_epoch_end
 class MyLightningModule(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.training_step_outputs = []

     def training_step(self, ...):
         loss = ...
+        self.training_step_outputs.append(loss)
         return loss

-    def training_epoch_end(self, outputs):
-        epoch_average = torch.stack([output["loss"] for output in outputs]).mean()
+    def on_train_epoch_end(self):
+        epoch_average = torch.stack(self.training_step_outputs).mean()
         self.log("training_epoch_average", epoch_average)
+        self.training_step_outputs.clear()  # free memory

The same suggestions apply to those implementing Callback.training_epoch_end
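For readers less familiar with Lightning internals, the migration boils down to this accumulate-and-clear bookkeeping. Here is a framework-free sketch of the same pattern; `FakeModule` and its plain float "losses" are illustrative stand-ins for a LightningModule and its tensors, not Lightning API:

```python
# Framework-free sketch of the accumulate-and-clear pattern shown above.
class FakeModule:
    def __init__(self):
        self.training_step_outputs = []

    def training_step(self, loss):
        self.training_step_outputs.append(loss)  # keep only what you need later
        return loss

    def on_train_epoch_end(self):
        outputs = self.training_step_outputs
        epoch_average = sum(outputs) / len(outputs)
        self.training_step_outputs.clear()  # free memory for the next epoch
        return epoch_average

module = FakeModule()
for loss in [1.0, 0.5, 0.75]:
    module.training_step(loss)
print(module.on_train_epoch_end())   # 0.75
print(module.training_step_outputs)  # [] -- nothing is retained across epochs
```

The key difference from the removed hooks: the user decides what to store, so nothing is retained unless explicitly appended, and `clear()` releases it as soon as the aggregation is done.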

validation_epoch_end -> on_validation_epoch_end
 class MyLightningModule(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.validation_step_outputs = []

     def validation_step(self, ...):
         loss = ...
+        self.validation_step_outputs.append(loss)
         return loss

-    def validation_epoch_end(self, outputs):
-        epoch_average = torch.stack(outputs).mean()
+    def on_validation_epoch_end(self):
+        epoch_average = torch.stack(self.validation_step_outputs).mean()
         self.log("validation_epoch_average", epoch_average)
+        self.validation_step_outputs.clear()  # free memory

The same suggestions apply to those implementing Callback.validation_epoch_end

test_epoch_end -> on_test_epoch_end
 class MyLightningModule(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.test_step_outputs = []

     def test_step(self, ...):
         loss = ...
+        self.test_step_outputs.append(loss)
         return loss

-    def test_epoch_end(self, outputs):
-        epoch_average = torch.stack(outputs).mean()
+    def on_test_epoch_end(self):
+        epoch_average = torch.stack(self.test_step_outputs).mean()
         self.log("test_epoch_average", epoch_average)
+        self.test_step_outputs.clear()  # free memory

The same suggestions apply to those implementing Callback.test_epoch_end

Example with two DataLoaders
 class MyLightningModule(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.test_step_outputs = [[], []]  # two dataloaders

     def test_step(self, batch, batch_idx, dataloader_idx=0):
         loss = ...
+        self.test_step_outputs[dataloader_idx].append(loss)
         return loss

-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self):
-        for dl_idx in range(len(outputs)):
+        for dl_idx in range(len(self.test_step_outputs)):
-            dataloader_epoch_average = torch.stack(outputs[dl_idx]).mean()
+            dataloader_epoch_average = torch.stack(self.test_step_outputs[dl_idx]).mean()
             self.log(f"test_epoch_average_dl_{dl_idx}", dataloader_epoch_average)
-            outputs[dl_idx].clear()
+            self.test_step_outputs[dl_idx].clear()

     def test_dataloader(self):
         dl1 = DataLoader(RandomDataset(32, 64), batch_size=2)
         dl2 = DataLoader(RandomDataset(32, 64), batch_size=2)
         return dl1, dl2
Example with strategy="dp" (DataParallel)
 class MyLightningModule(L.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.training_step_outputs = []
+        self.validation_step_outputs = []

     def training_step(self, batch, batch_idx):
         output = ...
         return output
 
     def validation_step(self, batch, batch_idx):
         output = ...
         return output

+    def training_step_end(self, training_step_output):
+        training_step_output = self.trainer.strategy.reduce(training_step_output)
+        self.training_step_outputs.append(training_step_output)
+        return training_step_output

+    def validation_step_end(self, validation_step_output):
+        self.validation_step_outputs.append(validation_step_output)
 
-    def training_epoch_end(self, outputs):
-        epoch_average = torch.stack([output["loss"] for output in outputs]).mean()
+    def on_train_epoch_end(self):
+        epoch_average = torch.stack(self.training_step_outputs).mean()
         self.log("training_epoch_average", epoch_average)
+        self.training_step_outputs.clear()  # free memory
 
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         epoch_average = torch.stack(self.validation_step_outputs).mean()
         self.log("validation_epoch_average", epoch_average)
+        self.validation_step_outputs.clear()  # free memory
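To clarify why the DP example reduces before appending: under DataParallel each step yields one output per device, and `self.trainer.strategy.reduce(...)` collapses them (a mean by default) into a single value before it is stored for the epoch. A framework-free sketch of that idea, where `reduce_mean` is an illustrative stand-in for `strategy.reduce` (real Lightning reduces tensors, not floats):

```python
# Stand-in for strategy.reduce: collapse per-device outputs into one value.
def reduce_mean(per_device_outputs):
    return sum(per_device_outputs) / len(per_device_outputs)

training_step_outputs = []

# pretend two GPUs each returned a loss for the same step
step_output = reduce_mean([0.5, 1.0])
training_step_outputs.append(step_output)

print(step_output)                 # 0.75
print(len(training_step_outputs))  # 1 -- one entry per step, not per device
```

Appending the reduced value rather than the raw per-device outputs keeps the stored list the same size regardless of how many devices are in use.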

If you have questions about how to migrate your use case, you can ask in this PR.

What does this PR do?

Removes the training_epoch_end, validation_epoch_end, and test_epoch_end hooks.
In favor of on_train_epoch_end, on_validation_epoch_end, and on_test_epoch_end.

These hooks were problematic: merely implementing them forced the loops to retain every step output for the whole epoch, which could cause memory issues for users unaware of this behavior.
They also increased the loops' complexity and were hard to hack or customize externally.

At runtime, we check whether the old hooks are overridden and, if they are, fail with an error message that points to the migration guide above.
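For the curious, detecting an overridden hook can be as simple as comparing the subclass's attribute with the base class's. This is a hedged sketch of the idea, not the actual implementation; Lightning has its own `is_overridden` utility, and `Base`, `Old`, and `New` here are illustrative stand-ins:

```python
# Sketch of detecting a removed hook override and failing with a pointer to
# the migration guide. Not Lightning code; the real check lives in the
# trainer's configuration validator.
class Base:
    def training_epoch_end(self, outputs):
        pass

def is_overridden(method_name, instance, parent=Base):
    # If the subclass did not redefine the method, attribute lookup resolves
    # to the very same function object as on the parent class.
    return getattr(type(instance), method_name) is not getattr(parent, method_name)

class Old(Base):  # still implements the removed hook -> should be rejected
    def training_epoch_end(self, outputs):
        pass

class New(Base):  # migrated -> passes validation
    pass

def validate(model):
    if is_overridden("training_epoch_end", model):
        raise RuntimeError(
            "Support for `training_epoch_end` has been removed in v2.0.0. "
            "Move your logic to `on_train_epoch_end` (see the migration guide)."
        )

validate(New())  # passes silently
try:
    validate(Old())
except RuntimeError as err:
    print("rejected:", err)
```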

Blocked by #16567

Fixes #8731
Closes #9380
Closes #9968
Closes #10878
Closes #11554

Follow-up things to address:
#8479: need to remove outputs from on_predict_epoch_end

Does your PR introduce any breaking changes? If yes, please list them.

Removes the hooks described above.

cc @Borda @justusschock @carmocca @awaelchli

@carmocca carmocca self-assigned this Jan 26, 2023
@github-actions github-actions bot added app (removed) Generic label for Lightning App package pl Generic label for PyTorch Lightning package labels Jan 26, 2023
@carmocca carmocca added breaking change Includes a breaking change lightningmodule pl.LightningModule hooks Related to the hooks API and removed app (removed) Generic label for Lightning App package labels Jan 26, 2023
@carmocca carmocca force-pushed the refactor/epoch-end-hook-removal branch from 2f301a1 to 55a5a51 Compare January 26, 2023 18:17
@github-actions github-actions bot added the app (removed) Generic label for Lightning App package label Jan 26, 2023
@carmocca carmocca added this to the 2.0 milestone Jan 30, 2023
@carmocca carmocca marked this pull request as ready for review January 30, 2023 17:55

github-actions bot commented Jan 30, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.11) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.11) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
pl-cpu (windows-2022, lightning, 3.9, 1.11) success
pl-cpu (windows-2022, lightning, 3.10, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
pl-cpu (slow, macOS-11, lightning, 3.8, 1.11) success
pl-cpu (slow, ubuntu-20.04, lightning, 3.8, 1.11) success
pl-cpu (slow, windows-2022, lightning, 3.8, 1.11) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success

These checks are required after the changes to src/lightning/pytorch/callbacks/callback.py, src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/demos/boring_classes.py, src/lightning/pytorch/loops/dataloader/evaluation_loop.py, src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py, src/lightning/pytorch/loops/epoch/training_epoch_loop.py, src/lightning/pytorch/loops/fit_loop.py, src/lightning/pytorch/loops/optimization/manual.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py, src/lightning/pytorch/utilities/types.py, tests/tests_pytorch/accelerators/test_ipu.py, tests/tests_pytorch/accelerators/test_tpu.py, tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py, tests/tests_pytorch/callbacks/test_callback_hook_outputs.py, tests/tests_pytorch/callbacks/test_lr_monitor.py, tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py, tests/tests_pytorch/checkpointing/test_model_checkpoint.py, tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py, tests/tests_pytorch/core/test_datamodules.py, tests/tests_pytorch/core/test_lightning_module.py, tests/tests_pytorch/core/test_lightning_optimizer.py, tests/tests_pytorch/helpers/deterministic_model.py, tests/tests_pytorch/loggers/test_all.py, tests/tests_pytorch/loggers/test_logger.py, tests/tests_pytorch/loggers/test_neptune.py, tests/tests_pytorch/loggers/test_tensorboard.py, tests/tests_pytorch/loops/optimization/test_optimizer_loop.py, tests/tests_pytorch/loops/test_evaluation_loop.py, tests/tests_pytorch/loops/test_evaluation_loop_flow.py, tests/tests_pytorch/loops/test_flow_warnings.py, tests/tests_pytorch/loops/test_loops.py, tests/tests_pytorch/loops/test_training_loop.py, tests/tests_pytorch/loops/test_training_loop_flow_dict.py, tests/tests_pytorch/loops/test_training_loop_flow_scalar.py, 
tests/tests_pytorch/models/test_hooks.py, tests/tests_pytorch/plugins/test_double_plugin.py, tests/tests_pytorch/strategies/test_deepspeed_strategy.py, tests/tests_pytorch/strategies/test_dp.py, tests/tests_pytorch/trainer/connectors/test_data_connector.py, tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py, tests/tests_pytorch/trainer/flags/test_fast_dev_run.py, tests/tests_pytorch/trainer/flags/test_min_max_epochs.py, tests/tests_pytorch/trainer/logging_/test_distributed_logging.py, tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py, tests/tests_pytorch/trainer/logging_/test_logger_connector.py, tests/tests_pytorch/trainer/logging_/test_loop_logging.py, tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py, tests/tests_pytorch/trainer/optimization/test_manual_optimization.py, tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py, tests/tests_pytorch/trainer/optimization/test_optimizers.py, tests/tests_pytorch/trainer/test_config_validator.py, tests/tests_pytorch/trainer/test_dataloaders.py, tests/tests_pytorch/trainer/test_trainer.py, tests/tests_pytorch/tuner/test_scale_batch_size.py, tests/tests_pytorch/utilities/test_all_gather_grad.py, tests/tests_pytorch/utilities/test_auto_restart.py, tests/tests_pytorch/utilities/test_fetching.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) success

These checks are required after the same changes listed above.

🟢 pytorch_lightning: Azure HPU
Check ID Status
pytorch-lightning (HPUs) success

These checks are required after the same changes listed above.

🟢 pytorch_lightning: Azure IPU
Check ID Status
pytorch-lightning (IPUs) success

These checks are required after the same changes listed above.

🟢 pytorch_lightning: Docs
Check ID Status
make-doctest (pytorch) success
make-html (pytorch) success

These checks are required after the changes to src/lightning/pytorch/callbacks/callback.py, src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/demos/boring_classes.py, src/lightning/pytorch/loops/dataloader/evaluation_loop.py, src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py, src/lightning/pytorch/loops/epoch/training_epoch_loop.py, src/lightning/pytorch/loops/fit_loop.py, src/lightning/pytorch/loops/optimization/manual.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py, src/lightning/pytorch/utilities/types.py, docs/source-pytorch/accelerators/accelerator_prepare.rst, docs/source-pytorch/common/lightning_module.rst, docs/source-pytorch/extensions/logging.rst, docs/source-pytorch/model/manual_optimization.rst, docs/source-pytorch/starter/style_guide.rst, docs/source-pytorch/visualize/logging_advanced.rst.

🟢 lightning_app: Tests workflow
Check ID Status
app-pytest (macOS-11, lightning, 3.8, latest) success
app-pytest (macOS-11, lightning, 3.8, oldest) success
app-pytest (macOS-11, app, 3.9, latest) success
app-pytest (ubuntu-20.04, lightning, 3.8, latest) success
app-pytest (ubuntu-20.04, lightning, 3.8, oldest) success
app-pytest (ubuntu-20.04, app, 3.9, latest) success
app-pytest (windows-2022, lightning, 3.8, latest) success
app-pytest (windows-2022, lightning, 3.8, oldest) success
app-pytest (windows-2022, app, 3.8, latest) success

These checks are required after the changes to src/lightning/app/utilities/introspection.py.

🟢 lightning_app: Examples
Check ID Status
app-examples (macOS-11, lightning, 3.9, latest) success
app-examples (macOS-11, lightning, 3.9, oldest) success
app-examples (macOS-11, app, 3.9, latest) success
app-examples (ubuntu-20.04, lightning, 3.9, latest) success
app-examples (ubuntu-20.04, lightning, 3.9, oldest) success
app-examples (ubuntu-20.04, app, 3.9, latest) success
app-examples (windows-2022, lightning, 3.9, latest) success
app-examples (windows-2022, lightning, 3.9, oldest) success
app-examples (windows-2022, app, 3.9, latest) success

These checks are required after the changes to src/lightning/app/utilities/introspection.py.

🟢 lightning_app: Azure
Check ID Status
App.cloud-e2e success

These checks are required after the changes to src/lightning/app/utilities/introspection.py.

🟢 lightning_app: Docs
Check ID Status
make-doctest (app) success
make-html (app) success

These checks are required after the changes to src/lightning/app/utilities/introspection.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/app/utilities/introspection.py, src/lightning/pytorch/callbacks/callback.py, src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/core/module.py, src/lightning/pytorch/demos/boring_classes.py, src/lightning/pytorch/loops/dataloader/evaluation_loop.py, src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py, src/lightning/pytorch/loops/epoch/training_epoch_loop.py, src/lightning/pytorch/loops/fit_loop.py, src/lightning/pytorch/loops/optimization/manual.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py, src/lightning/pytorch/utilities/types.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.10) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.10) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.10) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.10) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.10) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.10) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.10) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.10) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.10) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.10) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.10) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.10) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.10) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.10) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.10) success

These checks are required after the same changes listed under mypy above.

🟢 link-check
Check ID Status
markdown-link-check success

These checks are required after the changes to src/lightning/pytorch/CHANGELOG.md.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

ddelange added a commit to ddelange/autogluon that referenced this pull request May 4, 2023
yinweisu pushed a commit to autogluon/autogluon that referenced this pull request May 8, 2023
pallaviyn referenced this pull request in talhaanwarch/youtube-tutorials Jun 2, 2023
RuslanSergeev added a commit to RuslanSergeev/im2height that referenced this pull request May 27, 2024
@edmcman

edmcman commented Oct 16, 2024

Is there any way to get the old behavior without adding boilerplate?

@42elenz

42elenz commented Oct 23, 2024

I am wondering about the following:
I am using multi-GPU training, so I assumed I am using some sort of DP. That is why I followed the DP example. This is my implementation:

```python
class Contrastive_Training_Model(MultimodalBasis):
    def __init__(self, hparams, fold=''):
        super().__init__(hparams)

        self.save_hyperparameters(hparams)
        self.train_criterion_type = hparams.model.clip_pretrain_train_criterion
        self.label_type_fn = hparams.data.label_type_false_negative_class
        self.train_criterion, self.validation_criterion = choose_contrastive_loss_fct(self.train_criterion_type, hparams)
        self.debugging = hparams.logging.debug_level
        self.correct_val_ids_file = hparams.logging.correct_val_ids_file
        self.fold = fold
        self.training_step_outputs = []
        self.validation_step_outputs = []

        # sanity_check(self.train_criterion_type, self.label_type_fn)

    # This is called automatically by the trainer class
    def training_step(self, batch, batch_idx):
        """Trains the contrastive model."""
        train_mri = batch['cor_mri']  # can be corrupted depending on the settings
        train_questionnaire = batch['cor_questionnaire']  # can be corrupted depending on the settings
        mri_data = batch['mri']
        questionnaire_data = batch['questionnaire']
        id = batch['id']
        fn_label_class = batch['fn_label_class']
        ds_label_class = batch['ds_label_class']

        if self.label_type_fn == "binary":
            fn_label_class = fn_label_class.bool()
        mri_embeddings_projected = self.forward_mri(train_mri)
        questionnaire_embeddings_projected = self.forward_quest(train_questionnaire)
        loss, logits, labels = self.train_criterion(mri_embeddings_projected, questionnaire_embeddings_projected, fn_label_class)
        # self.training_step_outputs.append({'loss': loss, 'logits': logits, 'labels': labels, "ds_label_class": ds_label_class, 'ID': id, 'mri_embeddings': mri_embeddings_projected, 'questionnaire_embeddings': questionnaire_embeddings_projected, 'questionaire_data': questionnaire_data})
        # self.log("multimodal.train.loss", loss, on_epoch=True, on_step=False)  # implement later
        # if len(im_views[0]) == self.hparams.batch_size:
        #     self.calc_and_log_train_embedding_acc(logits=logits, labels=labels, modality='multimodal')

        return {'loss': loss,
                'logits': logits,
                "ds_label_class": ds_label_class,
                'fn_label_class': fn_label_class,
                'ID': id,
                'mri_embeddings': mri_embeddings_projected,
                'questionnaire_embeddings': questionnaire_embeddings_projected,
                'questionaire_data': questionnaire_data}

    def training_step_end(self, training_step_output):
        training_step_output = self.trainer.strategy.reduce(training_step_output)
        self.training_step_outputs.append(training_step_output)
        return training_step_output

    def validation_step(self, batch, batch_idx):
        """Validates the contrastive model."""
        val_mri = batch['cor_mri']
        val_questionnaire = batch['cor_questionnaire']
        mri_data = batch['mri']
        questionnaire_data = batch['questionnaire']
        id = batch['id']
        fn_label_class = batch['fn_label_class']
        ds_label_class = batch['ds_label_class']

        mri_embeddings_projected = self.forward_mri(val_mri)
        questionnaire_embeddings_projected = self.forward_quest(val_questionnaire)
        loss, logits, quest_logits, labels = self.validation_criterion(mri_embeddings_projected, questionnaire_embeddings_projected, fn_label_class)
        # self.validation_step_outputs.append({'loss': loss, 'mri_logits': logits, 'quest_logits': quest_logits, 'logits': logits, "ds_label_class": ds_label_class, 'fn_label_class': fn_label_class, 'ID': id, 'mri_embeddings': mri_embeddings_projected, 'questionnaire_embeddings': questionnaire_embeddings_projected, 'questionaire_data': questionnaire_data})
        return {'loss': loss,
                'mri_logits': logits,
                'quest_logits': quest_logits,
                'logits': logits,
                "ds_label_class": ds_label_class,
                'fn_label_class': fn_label_class,
                'ID': id,
                'mri_embeddings': mri_embeddings_projected,
                'questionnaire_embeddings': questionnaire_embeddings_projected,
                'questionaire_data': questionnaire_data}

    def validation_step_end(self, validation_step_output):
        self.validation_step_outputs.append(validation_step_output)

    # At the end of the epoch, all outputs are in a list.
    def on_train_epoch_end(self):
        train_outputs = self.training_step_outputs
        epoch_loss, epoch_accuracy = evaluation_of_contrastive_outputs(train_outputs, self.debugging, evaluation_type="train")
        fold = self.fold
        self.log(f"cont.train.loss{fold}", epoch_loss, on_epoch=True, on_step=False)
        self.log(f"cont.train.acc{fold}", epoch_accuracy, on_epoch=True, on_step=False)
        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        val_outputs = self.validation_step_outputs
        fold = self.fold
        epoch_loss, epoch_accuracy, mri_accuracy, questionnaire_accuracy, mri_quest_accuracy = evaluation_of_contrastive_outputs(val_outputs, self.debugging, evaluation_type="validation", correct_val_ids_file=self.correct_val_ids_file)
        self.log("cont.val.loss", epoch_loss, on_epoch=True, on_step=False)
        self.log(f"cont.val.accuracy{fold}", epoch_accuracy, on_epoch=True, on_step=False)
        self.log(f"cont.val.mri_accuracy{fold}", mri_accuracy, on_epoch=True, on_step=False)
        self.log(f"cont.val.questionnaire_accuracy{fold}", questionnaire_accuracy, on_epoch=True, on_step=False)
        self.log(f"cont.val.mri_questionnaire_accuracy{fold}", mri_quest_accuracy, on_epoch=True, on_step=False)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.training.lr,
            weight_decay=self.hparams.training.weight_decay)
        return optimizer
```



**Unfortunately, during the pre-run, when the basic validation sanity check is calculated, `validation_step_end()` is not called, so I get an error (division by zero in `evaluation_of_contrastive_outputs`). What else can I do other than checking for this case?**

@bnestor

bnestor commented Dec 20, 2024

@42elenz I had to add `self.validation_step_outputs.append(validation_step_output)` to my `validation_step`. Then everything worked fine.
