A template for simple deep learning projects using Lightning
English | 中文
PyTorch Lightning is to deep learning project development what MVC frameworks (such as Spring and Django) are to website development. While it is possible to implement everything from scratch for maximum flexibility (especially since PyTorch and its ecosystem are already quite straightforward), a framework helps you prototype quickly with guidance from "best practices" (a personal opinion), saves a lot of boilerplate code through reuse, and lets you focus on scientific innovation rather than engineering challenges. This template is built on the full Lightning suite, follows the principle of Occam's razor, and is researcher-friendly. It includes a simple handwritten-digit recognition task on the MNIST dataset, and the repository also contains some tips for reference.
Using PyTorch Lightning as a deep learning framework:
Most deep learning code can be divided into the following three parts (reference, in Chinese):

- Research code: This part concerns the model itself and generally deals with customizing the model's structure and training. In Lightning, this code is abstracted as the `pl.LightningModule` class. Although dataset definitions could also be placed here, this is not recommended: the dataset is not specific to the experiment and belongs in `pl.LightningDataModule` instead.
- Engineering code: This part is highly repeatable across projects, such as setting up early stopping, 16-bit precision, and distributed GPU training. In Lightning, this code is abstracted as the `pl.Trainer` class.
- Non-essential code: This code helps with conducting experiments but is not directly related to the experiment itself and can even be omitted, for example gradient checking and writing logs to TensorBoard. In Lightning, this code is abstracted as `Callback` classes, which are registered to the `pl.Trainer`.
The advantages of using Lightning:

- Custom training processes and learning-rate adjustment strategies can be implemented through the various hook functions of `pl.LightningModule`.
- The model and data no longer need to be explicitly assigned to devices (`tensor.to`, `tensor.cuda`, etc.); `pl.Trainer` handles this automatically, thereby supporting various acceleration devices such as CPU, GPU, and TPU.
- `pl.Trainer` implements various training strategies, such as automatic mixed-precision training, multi-GPU training, and distributed training.
- `pl.Trainer` implements multiple callbacks, such as automatic model saving, automatic config saving, and automatic saving of visualization results.
Using PyTorch Lightning CLI as a command-line tool:

- Using `LightningCLI` as the program entry point, model, data, and training parameters can be set through configuration files or command-line arguments, enabling quick switching between multiple experiments.
- `pl.LightningModule.save_hyperparameters()` saves the model's hyperparameters and automatically generates a command-line argument table, eliminating the need for tools such as `argparse` or `hydra`.
Using Torchmetrics as a metric computation tool:

- `Torchmetrics` provides implementations of many metrics, such as `Accuracy`, `Precision`, and `Recall`.
- It is integrated with Lightning and is compatible with parallel training strategies; data is automatically aggregated to the main process for metric computation.
[Optional] Using WandB to track experiments
```mermaid
graph TD;
A[LightningCLI]---B[LightningModule]
A---C[LightningDataModule]
B---D[models]
B---E[metrics]
B---F[...]
C---G[dataloaders]
G---H[datasets]
```
```
├── configs                      # Configuration files
│   ├── data                     # Dataset configuration
│   │   └── mnist.yaml           # Example configuration for MNIST dataset
│   ├── model                    # Model configuration
│   │   └── simplenet.yaml       # Example configuration for SimpleNet model
│   └── default.yaml             # Default configuration
├── data                         # Dataset directory
├── logs                         # Log directory
├── notebooks                    # Jupyter Notebook directory
├── scripts                      # Script directory
│   └── clear_wandb_cache.py     # Example script to clear wandb cache
├── src                          # Source code directory
│   ├── callbacks                # Callbacks directory
│   │   └── __init__.py
│   ├── data_modules             # Data module directory
│   │   ├── __init__.py
│   │   └── mnist.py             # Example data module for MNIST dataset
│   ├── metrics                  # Metrics directory
│   │   └── __init__.py
│   ├── models                   # Model directory
│   │   ├── __init__.py
│   │   └── simplenet.py         # Example SimpleNet model
│   ├── modules                  # Module directory
│   │   ├── __init__.py
│   │   └── mnist_module.py      # Example MNIST module
│   ├── utils                    # Utility directory
│   │   ├── __init__.py
│   │   └── cli.py               # CLI tool
│   ├── __init__.py
│   └── main.py                  # Main program entry point
├── .env.example                 # Example environment variable file
├── .gitignore                   # Ignore files for git
├── .project-root                # Project root indicator file for pyrootutils
├── LICENSE                      # Open source license
├── pyproject.toml               # Configuration file for Black and Ruff
├── README.md                    # Project documentation
├── README_PROJECT.md            # Project documentation template
├── README_ZH.md                 # Project documentation in Chinese
└── requirements.txt             # Dependency list
```
```bash
# Clone the project
git clone https://github.com/DavidZhang73/pytorch-lightning-template <project_name>
cd <project_name>

# [Optional] Create a conda virtual environment
conda create -n <env_name> python=<3.8|3.9|3.10>
conda activate <env_name>

# [Optional] Use mamba instead of conda to speed things up
conda install mamba -n base -c conda-forge

# [Optional] Install PyTorch according to the official website to get GPU support
# https://pytorch.org/get-started/

# Install dependencies
pip install -r requirements.txt
```
- Define the dataset by inheriting `pl.LightningDataModule` in `src/data_modules`.
- Define the dataset configuration file in `configs/data` as parameters for the custom `pl.LightningDataModule`.
- Define the model by inheriting `nn.Module` in `src/models`.
- Define metrics by inheriting `torchmetrics.Metric` in `src/metrics`.
- Define the training module by inheriting `pl.LightningModule` in `src/modules`.
- Define the configuration file for the training module in `configs/model` as parameters for the custom `pl.LightningModule`.
- Configure `pl.Trainer`, logging, and other parameters in `configs/default.yaml`.
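For instance, a data configuration file in `configs/data` might look like this (the keys must match the `__init__` parameters of your `pl.LightningDataModule`; the class name and values below are illustrative):

```yaml
data:
  class_path: src.data_modules.MNISTDataModule # illustrative class name
  init_args:
    data_dir: data/
    batch_size: 32
    num_workers: 4
```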
Fit

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Validate

```bash
python src/main.py validate -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Test

```bash
python src/main.py test -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Inference

```bash
python src/main.py predict -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.logger.name exp1
```

Debug

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --trainer.fast_dev_run true
```

Resume

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --ckpt_path <ckpt_path> --trainer.logger.id exp1_id
```
Using the `print_config` functionality of `jsonargparse`, you can obtain the parsed arguments and generate a default `yaml` file. Note that the `yaml` files for `data` and `model` must be configured first.

```bash
python src/main.py fit -c configs/data/mnist.yaml -c configs/model/simplenet.yaml --print_config
```
Prepare a config file for the CLI
This template implements a custom CLI (`CustomLightningCLI`) to achieve the following functions:

- When starting the program, the configuration file is automatically saved to the corresponding log directory (for `WandbLogger` only).
- When starting the program, the optimizer and scheduler configurations are saved to the loggers.
- When starting the program, the default configuration file is automatically loaded.
- After testing completes, the `checkpoint_path` used for testing is printed.
- Some command-line arguments are added:
  - `--ignore_warnings` (default: `False`): Ignore all warnings.
  - `--test_after_fit` (default: `False`): Automatically test after each training run.
  - `--git_commit_before_fit` (default: `False`): Run `git commit` before each training run; the commit message is `{logger.name}_{logger.version}` (for `WandbLogger` only).
CONFIGURE HYPERPARAMETERS FROM THE CLI (EXPERT)
When running on a server, especially one with many CPU cores (>= 24), you may encounter the problem of too many `numpy` threads, which can cause the experiment to hang inexplicably. You can limit the number of threads by setting environment variables (in the `.env` file).
```
OMP_NUM_THREADS=8
MKL_NUM_THREADS=8
GOTO_NUM_THREADS=8
NUMEXPR_NUM_THREADS=8
OPENBLAS_NUM_THREADS=8
MKL_DOMAIN_NUM_THREADS=8
VECLIB_MAXIMUM_THREADS=8
```
The `.env` file is automatically loaded into the environment by `pyrootutils` via `python-dotenv`.
Stack Overflow: Limit number of threads in numpy
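The same limits can also be applied from plain Python; the key point is that they must be set before `numpy` (or `torch`) is first imported (`8` is just an example value):

```python
import os

# Cap BLAS/OpenMP thread pools; must run before numpy/torch are first imported
for var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[var] = "8"

import numpy as np  # picks up the thread limits set above
```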
When you delete an experiment from the `wandb` web page, its cache still exists in the local `wandb` directory; you can use the `scripts/clear_wandb_cache.py` script to clear it.
Inspired by,