Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Trainer.init_module and LightningModule.configure_model #18004

Merged
merged 14 commits into from
Jul 14, 2023

Conversation

carmocca
Copy link
Contributor

@carmocca carmocca commented Jul 6, 2023

What does this PR do?

Port of #17462 for the Trainer

Key implementation differences:

  • No init_context support for DeepSpeedPrecisionPlugin because it doesn't support true precision (unlike Fabric)
  • Trainer only exposes init_module which functionally matches fabric.init_tensor because otherwise it would require that processes have been launched when it is called, but that won't happen until trainer.fit().
  • [NEW HOOK] Instead of Trainer.init_module, users are suggested LightningModule.configure_model with sharded strategies
  • LightningModule.configure_sharded_model is deprecated in favor of LightningModule.configure_model

Generated docs: https://pytorch-lightning--18004.org.readthedocs.build/

Enables Lightning-AI/litgpt#228

cc @Borda @awaelchli @carmocca @justusschock

@carmocca carmocca added feature Is an improvement or enhancement strategy: deepspeed trainer strategy strategy: fsdp Fully Sharded Data Parallel pl Generic label for PyTorch Lightning package labels Jul 6, 2023
@carmocca carmocca added this to the 2.1 milestone Jul 6, 2023
@carmocca carmocca self-assigned this Jul 6, 2023
@github-actions github-actions bot added fabric lightning.fabric.Fabric app (removed) Generic label for Lightning App package labels Jul 6, 2023
@carmocca carmocca force-pushed the carmocca/trainer-init-module branch from ceb470f to 7ffddf2 Compare July 6, 2023 18:27
@github-actions github-actions bot removed the app (removed) Generic label for Lightning App package label Jul 6, 2023
@carmocca carmocca force-pushed the carmocca/trainer-init-module branch from 2535351 to f0939ea Compare July 6, 2023 18:57
@carmocca carmocca force-pushed the carmocca/trainer-init-module branch from dbf2f54 to eb942ea Compare July 6, 2023 19:39
@carmocca carmocca changed the title Add trainer.init_module and trainer.init_tensor Add trainer.init_module Jul 10, 2023
@carmocca carmocca force-pushed the carmocca/trainer-init-module branch 7 times, most recently from 6fed2f8 to 4613d7c Compare July 11, 2023 00:53
@carmocca carmocca force-pushed the carmocca/trainer-init-module branch from b7562a3 to 49629af Compare July 11, 2023 01:18
@github-actions
Copy link
Contributor

github-actions bot commented Jul 13, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.11) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
pl-cpu (windows-2022, lightning, 3.8, 1.11) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success

These checks are required after the changes to src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/plugins/precision/double.py, src/lightning/pytorch/plugins/precision/fsdp.py, src/lightning/pytorch/strategies/deepspeed.py, src/lightning/pytorch/strategies/fsdp.py, src/lightning/pytorch/strategies/strategy.py, src/lightning/pytorch/trainer/call.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py, tests/tests_pytorch/deprecated_api/test_no_removal_version.py, tests/tests_pytorch/models/test_hooks.py, tests/tests_pytorch/plugins/test_double_plugin.py, tests/tests_pytorch/strategies/test_common.py, tests/tests_pytorch/strategies/test_ddp.py, tests/tests_pytorch/strategies/test_deepspeed_strategy.py, tests/tests_pytorch/strategies/test_fsdp.py, tests/tests_pytorch/trainer/logging_/test_logger_connector.py, tests/tests_pytorch/trainer/test_trainer.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) success

These checks are required after the changes to src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/plugins/precision/double.py, src/lightning/pytorch/plugins/precision/fsdp.py, src/lightning/pytorch/strategies/deepspeed.py, src/lightning/pytorch/strategies/fsdp.py, src/lightning/pytorch/strategies/strategy.py, src/lightning/pytorch/trainer/call.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py, tests/tests_pytorch/deprecated_api/test_no_removal_version.py, tests/tests_pytorch/models/test_hooks.py, tests/tests_pytorch/plugins/test_double_plugin.py, tests/tests_pytorch/strategies/test_common.py, tests/tests_pytorch/strategies/test_ddp.py, tests/tests_pytorch/strategies/test_deepspeed_strategy.py, tests/tests_pytorch/strategies/test_fsdp.py, tests/tests_pytorch/trainer/logging_/test_logger_connector.py, tests/tests_pytorch/trainer/test_trainer.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/plugins/precision/double.py, src/lightning/pytorch/plugins/precision/fsdp.py, src/lightning/pytorch/strategies/deepspeed.py, src/lightning/pytorch/strategies/fsdp.py, src/lightning/pytorch/strategies/strategy.py, src/lightning/pytorch/trainer/call.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py.

🟢 pytorch_lightning: Docs
Check ID Status
make-doctest (pytorch) success
make-html (pytorch) success

These checks are required after the changes to src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/plugins/precision/double.py, src/lightning/pytorch/plugins/precision/fsdp.py, src/lightning/pytorch/strategies/deepspeed.py, src/lightning/pytorch/strategies/fsdp.py, src/lightning/pytorch/strategies/strategy.py, src/lightning/pytorch/trainer/call.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py, docs/source-pytorch/advanced/model_parallel.rst, docs/source-pytorch/common/lightning_module.rst, docs/source-pytorch/integrations/strategies/colossalai.rst.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.11) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success

These checks are required after the changes to tests/tests_fabric/strategies/test_fsdp_integration.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) success

These checks are required after the changes to tests/tests_fabric/strategies/test_fsdp_integration.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/plugins/precision/double.py, src/lightning/pytorch/plugins/precision/fsdp.py, src/lightning/pytorch/strategies/deepspeed.py, src/lightning/pytorch/strategies/fsdp.py, src/lightning/pytorch/strategies/strategy.py, src/lightning/pytorch/trainer/call.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.10) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.10) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.10) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.10) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.10) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.10) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.10) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.10) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.10) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.10) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.10) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.10) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.10) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.10) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.10) success

These checks are required after the changes to src/lightning/pytorch/core/hooks.py, src/lightning/pytorch/plugins/precision/double.py, src/lightning/pytorch/plugins/precision/fsdp.py, src/lightning/pytorch/strategies/deepspeed.py, src/lightning/pytorch/strategies/fsdp.py, src/lightning/pytorch/strategies/strategy.py, src/lightning/pytorch/trainer/call.py, src/lightning/pytorch/trainer/configuration_validator.py, src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py, src/lightning/pytorch/trainer/trainer.py.

🟢 link-check
Check ID Status
check-md-links / markdown-link-check success

These checks are required after the changes to src/lightning/fabric/CHANGELOG.md, src/lightning/pytorch/CHANGELOG.md.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@carmocca carmocca changed the title Add trainer.init_module Add Trainer.init_module and LightningModule.configure_model Jul 13, 2023
@carmocca carmocca added hooks Related to the hooks API and removed fabric lightning.fabric.Fabric strategy: deepspeed strategy: fsdp Fully Sharded Data Parallel labels Jul 13, 2023
Copy link
Contributor

@awaelchli awaelchli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Minor questions and comments only

src/lightning/fabric/CHANGELOG.md Show resolved Hide resolved
src/lightning/pytorch/trainer/call.py Show resolved Hide resolved
src/lightning/pytorch/trainer/trainer.py Show resolved Hide resolved
tests/tests_pytorch/strategies/test_fsdp.py Outdated Show resolved Hide resolved
@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Jul 14, 2023
@mergify mergify bot added the ready PRs ready to be merged label Jul 14, 2023
@carmocca carmocca merged commit 340eecd into master Jul 14, 2023
@carmocca carmocca deleted the carmocca/trainer-init-module branch July 14, 2023 17:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
fabric lightning.fabric.Fabric feature Is an improvement or enhancement hooks Related to the hooks API pl Generic label for PyTorch Lightning package ready PRs ready to be merged strategy trainer
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants