A clean and flexible PyTorch Lightning template to kickstart and structure your deep learning project, ensuring an efficient workflow, reproducibility, and easy extensibility for rapid experiments.
PyTorch Lightning is a deep learning framework designed for professional AI researchers and engineers. It frees users from boilerplate code (e.g., multi-GPU/TPU/HPU training, early stopping, and checkpointing) so they can focus on going from idea to paper/production.
This Lightning template leverages Lightning CLI to separate configuration from source code, guaranteeing reproducibility of experiments, and incorporates many other best practices.
- Compared to Lightning-Hydra-Template: Our template provides similar functionality through a simple and straightforward encapsulation of Lightning's built-in CLI, making it suitable for users who prefer minimal setup without an additional Hydra layer.
Note: This is an unofficial project that lacks comprehensive tests and continuous integration.
git clone https://github.com/YourGithubName/your-repository-name
cd your-repository-name
# [SUGGESTED] use conda environment
conda env create -n env-name -f environment.yaml
conda activate env-name
# [ALTERNATIVE] install requirements directly
pip install -r requirements.txt
# Run the sample script, i.e., ./run fit --config configs/mnist.yaml
bash -x scripts/run.sh
Before using this template, please read the basic PyTorch Lightning documentation: Lightning in 15 minutes.
- Define a Lightning Module (Examples: mnist_model.py and glue_transformer.py); a minimal skeleton of a module and datamodule is sketched after this list
- Define a Lightning DataModule (Examples: mnist_datamodule.py and glue_datamodule.py)
- Prepare your experiment configs (Examples: mnist.yaml and mrpc.yaml)
- Run experiments (cf., Configure hyperparameters from the CLI)
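As a rough illustration of the first two steps, here is a hypothetical, minimal skeleton of a LightningModule and LightningDataModule. It is not the template's actual mnist_model.py or mnist_datamodule.py; the class names, the lightning.pytorch import path, and the random dummy dataset are assumptions made to keep the sketch self-contained.

```python
# Hypothetical minimal skeleton, not the template's actual code.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class DemoModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # Expose __init__ arguments to the CLI/config system.
        self.save_hyperparameters()
        self.layer = nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.layer(x.flatten(1)), y)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


class DemoDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()

    def setup(self, stage=None):
        # Random tensors stand in for a real dataset to keep the sketch runnable.
        self.train_set = TensorDataset(
            torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,))
        )

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.hparams.batch_size)
```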
- To see the available commands, type:
./run --help
- Train a model from the config:
./run fit --config configs/mnist.yaml
- Override config options:
./run fit --config configs/mnist.yaml --trainer.precision 16 --model.learning_rate 0.1 --data.batch_size 64
- Separate model and datamodule configs:
./run fit --config configs/data.yaml --config configs/model.yaml
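All of the commands above go through a thin wrapper around Lightning's LightningCLI. The following is a hypothetical sketch of such an entry point, not the template's actual run script; it only illustrates how the model, datamodule, and trainer are resolved from configuration files or command-line arguments instead of being hard-coded.

```python
# Hypothetical sketch of a LightningCLI entry point; the template's actual
# run script may differ in details.
from lightning.pytorch.cli import LightningCLI


def main():
    # With no model or datamodule class pinned here, both are resolved from the
    # YAML config (class_path/init_args) or from --model / --data arguments,
    # so an experiment is described entirely by its configuration.
    LightningCLI()


if __name__ == "__main__":
    main()
```

With such an entry point, overrides like --trainer.precision or --model.learning_rate simply rewrite fields of the parsed configuration before the Trainer is built.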
The directory structure of a project looks like this:
lightning-template
├── configs                  ← Directory of Configs
│   ├── mnist.yaml
│   ├── mrpc.yaml
│   ├── presets              ← Preset configs for Lightning features
│   └── sweep_mnist.yaml
├── data                     ← Directory of Data
├── environment.yaml
├── models                   ← Directory of Models
├── notebooks                ← Directory of Notebooks
├── pyproject.toml
├── README.md
├── requirements.txt
├── results                  ← Directory of Results
├── run                      ← Script to Run Lightning CLI
├── scripts                  ← Directory of Scripts
│   ├── print_results
│   ├── run.sh
│   ├── sweep                ← Script to sweep Experiments
│   └── sweep_mnist.sh
└── src                      ← Directory of Source Code
    ├── callbacks
    ├── datamodules
    ├── models
    ├── utils
    └── vendor               ← Directory of Third-Party Code
- Use conda to manage environments.
- Leverage Lightning's awesome features (cf., How-to Guides & Glossary).
- Use pre-commit and ruff to check and format code with configuration in pyproject.toml and .pre-commit-config.yaml.
pre-commit install
- Use direnv to automatically switch environments and set variables (cf., .envrc).
λ cd lightning-template
direnv: loading ~/lightning-template/.envrc
direnv: export +CONDA_DEFAULT_ENV +CONDA_EXE +CONDA_PREFIX +CONDA_PROMPT_MODIFIER +CONDA_PYTHON_EXE +CONDA_SHLVL +_CE_CONDA +_CE_M ~PATH ~PYTHONPATH
- Add the project root to PATH to use the run script directly.
export PATH=$PWD:$PWD/scripts:$PATH
run fit --config configs/mnist.yaml
- Add the project root to PYTHONPATH to avoid modifying sys.path in scripts.
export PYTHONPATH=$PWD${PYTHONPATH:+":$PYTHONPATH"}
- Save private variables to .env.
- Use shtab to generate the shell completion file.
- Use Ray Tune to sweep parameters or run hyperparameter searches (cf., sweep_cli.py); a rough sketch follows this list.
bash ./scripts/sweep --config configs/sweep_mnist.yaml
- Use third-party loggers (e.g., W&B and Aim) to track experiments.
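As a rough sketch of the sweep idea referenced above, the snippet below launches one Lightning CLI run per trial. It is an assumption about how such a sweep could be wired up, not the template's actual sweep_cli.py, and the hyperparameter names and search spaces are placeholders borrowed from the MNIST example.

```python
# Hypothetical Ray Tune sweep over Lightning CLI runs; sweep_cli.py and
# configs/sweep_mnist.yaml in the template may be organized differently.
import subprocess
from ray import tune


def run_trial(config):
    # Each trial launches the CLI with hyperparameters overridden on the
    # command line; a real sweep would also report the resulting validation
    # metric back to Tune. Ray runs each trial in its own working directory,
    # so an absolute path to the run script is safer in practice.
    subprocess.run(
        [
            "./run", "fit",
            "--config", "configs/mnist.yaml",
            "--model.learning_rate", str(config["learning_rate"]),
            "--data.batch_size", str(config["batch_size"]),
        ],
        check=True,
    )


tuner = tune.Tuner(
    run_trial,
    param_space={
        "learning_rate": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    },
    tune_config=tune.TuneConfig(num_samples=8),
)
tuner.fit()
```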
What it does
First, install dependencies
# clone project
git clone https://github.com/YourGithubName/your-repository-name
cd your-repository-name
# [SUGGESTED] use conda environment
conda env create -f environment.yaml
conda activate lit-template
# [ALTERNATIVE] install requirements directly
pip install -r requirements.txt
Next, to obtain the main results of the paper:
# commands to get the main results
You can also run experiments with the run script.
# fit with the demo config
./run fit --config configs/demo.yaml
# or specific command line arguments
./run fit --model MNISTModel --data MNISTDataModule --data.batch_size 32 --trainer.gpus 0
# evaluate with the checkpoint
./run test --config configs/demo.yaml --ckpt_path ckpt_path
# get the script help
./run --help
./run fit --help
@article{YourName,
title={Your Title},
author={Your team},
journal={Location},
year={Year}
}