Lightning v2.2
Lightning AI is excited to announce the release of Lightning 2.2 ⚡
Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.
While our previous release was packed with many big new features, this time around we're rolling out mainly improvements based on feedback from the community. And of course, as the name implies, this release fully supports the latest PyTorch 2.2 🎉
Highlights
Monitoring Throughput
Lightning now has built-in utilities to measure throughput metrics such as batches/sec, samples/sec and Model FLOP Utilization (MFU) (#18848).
Trainer:
For the Trainer, this comes in the form of a ThroughputMonitor callback. To track samples/sec, you need to provide a function that tells the monitor how to extract the batch dimension from your input. If you also want to track MFU, you can provide a sample forward pass (measured with measure_flops and stored as flops_per_batch on your LightningModule), and the ThroughputMonitor will automatically estimate the utilization based on the hardware you are running on:
import torch
import lightning as L
from lightning.pytorch.callbacks import ThroughputMonitor
from lightning.fabric.utilities.throughput import measure_flops


class MyModel(L.LightningModule):
    def setup(self, stage):
        # Instantiate a copy of the model on the meta device: no memory is allocated
        with torch.device("meta"):
            model = MyModel()

        def sample_forward():
            batch = torch.randn(..., device="meta")
            return model(batch)

        # The ThroughputMonitor reads this attribute to compute FLOPs/sec and MFU
        self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)


throughput = ThroughputMonitor(
    batch_size_fn=lambda batch: batch.size(0),
    # optional, if your samples have a length (like number of tokens)
    length_fn=lambda batch: batch.size(1),
)
trainer = L.Trainer(log_every_n_steps=10, callbacks=throughput, logger=...)

model = MyModel()
trainer.fit(model)
The results get automatically sent to the logger if one is configured on the Trainer.
Fabric:
For Fabric, the ThroughputMonitor is a simple utility object on which you call .update() and .compute_and_log() during the training loop:
from time import time

import torch
import lightning as L
from lightning.fabric.utilities import ThroughputMonitor

fabric = L.Fabric(logger=...)
throughput = ThroughputMonitor(fabric)

# train_dataloader, do_work() and batch_size are placeholders for your own code
t0 = time()
for batch_idx, batch in enumerate(train_dataloader):
    do_work()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # required or else time() won't be correct
    # update() expects cumulative totals since the start of training
    throughput.update(
        time=(time() - t0),
        batches=(batch_idx + 1),
        samples=((batch_idx + 1) * batch_size),
    )
    if batch_idx % 10 == 0:
        throughput.compute_and_log(step=batch_idx)
Check out our TinyLlama LLM pretraining script for a full example using Fabric's ThroughputMonitor.
The throughput utilities can report:
- batches per second (per process and across processes)
- samples per second (per process and across processes)
- items per second, e.g. tokens (per process and across processes)
- FLOPs per second (per process and across processes)
- model FLOP utilization (MFU) (per process)
- total time, total samples, total batches, and total items (per process)
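The rates listed above are, roughly, differences of cumulative counters divided by the elapsed time between two measurements. The following is a simplified, hypothetical re-implementation and not Lightning's code. It assumes you record cumulative time, batches, samples, and FLOPs, as the Fabric loop above does, and that you know your device's peak FLOP/s. MFU is taken here as achieved FLOP/s divided by peak FLOP/s. The Snapshot and compute_rates names and all numbers are illustrative.

```python
from dataclasses import dataclass


@dataclass
class Snapshot:
    """Cumulative counters recorded at one point in training (hypothetical helper)."""
    time: float    # total elapsed seconds
    batches: int   # total batches processed
    samples: int   # total samples processed
    flops: float   # total floating-point operations performed


def compute_rates(first: Snapshot, last: Snapshot, peak_flops_per_sec: float) -> dict:
    """Compute throughput over the window between two snapshots."""
    elapsed = last.time - first.time
    if elapsed <= 0:
        raise ValueError("snapshots must be ordered in time")
    flops_per_sec = (last.flops - first.flops) / elapsed
    return {
        "batches_per_sec": (last.batches - first.batches) / elapsed,
        "samples_per_sec": (last.samples - first.samples) / elapsed,
        "flops_per_sec": flops_per_sec,
        # Model FLOP Utilization: achieved FLOP/s as a fraction of the hardware peak
        "mfu": flops_per_sec / peak_flops_per_sec,
    }


# Illustrative numbers only: 50 batches of 32 samples in 10 seconds
start = Snapshot(time=0.0, batches=0, samples=0, flops=0.0)
end = Snapshot(time=10.0, batches=50, samples=1600, flops=5e15)
rates = compute_rates(start, end, peak_flops_per_sec=1e15)
```

In this toy window, the model achieves half of the assumed peak, so the MFU comes out at 0.5.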
Improved Handling of Evaluation Mode
When you train a model with validation enabled, the Trainer automatically calls .eval() when transitioning to the validation loop, and .train() when validation ends. Until now, this had the unfortunate side effect that any submodules in your LightningModule that were in evaluation mode got reset to train mode. In Lightning 2.2, the Trainer captures the mode of every submodule before switching to validation and restores those modes when validation ends (#18951). This improvement helps users avoid silent correctness bugs and removes boilerplate code for managing frozen layers.
import lightning as L


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.trainable_module = ...

        # This will now stay in eval mode
        self.frozen_module = ...
        self.frozen_module.eval()

    def training_step(self, batch):
        # Previously: all modules were put in train mode
        # Now: modules stay in the mode they were set up with
        assert self.trainable_module.training
        assert not self.frozen_module.training
        ...

    def validation_step(self, batch):
        # All modules are in eval mode
        ...


model = LitModel()
trainer = L.Trainer()
trainer.fit(model)
If you have overridden any of the LightningModule.on_{validation,test,predict}_model_{eval,train} hooks, they will still get called and execute your custom logic, but you no longer need them just to preserve the eval mode of frozen modules.
Important
In some libraries, for example HuggingFace, models are created in evaluation mode by default (e.g. HFModel.from_pretrained(...)). Starting from 2.2, you will have to call .train() on these models if you intend to train them.
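The capture-and-restore behavior described in this section can be illustrated in plain PyTorch. This is a minimal sketch of the idea and not the Trainer's actual implementation. The helper names capture_modes and restore_modes are hypothetical.

```python
import torch
from torch import nn


def capture_modes(module: nn.Module) -> dict:
    """Record the train/eval flag of every submodule, keyed by its qualified name."""
    return {name: m.training for name, m in module.named_modules()}


def restore_modes(module: nn.Module, modes: dict) -> None:
    """Put every submodule back into the mode recorded by capture_modes()."""
    for name, m in module.named_modules():
        # Set the flag directly: calling m.train() would also recurse into children
        m.training = modes[name]


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1), nn.Linear(4, 2))
model[0].eval()  # pretend the first layer is frozen

modes = capture_modes(model)
model.eval()  # roughly what happens when validation starts
with torch.no_grad():
    model(torch.randn(3, 4))
restore_modes(model, modes)  # roughly what happens when validation ends
```

After restore_modes() runs, the frozen first layer stays in eval mode while the rest of the network is back in train mode. A blanket model.train() call would instead have flipped the frozen layer as well.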
Converting FSDP Checkpoints
In the previous release, we introduced distributed checkpointing with FSDP to speed up saving and loading checkpoints for big models. These checkpoints are saved in a special format: a folder in which the shards from each GPU are stored in separate files. While these checkpoints can be loaded back very easily with the Lightning Trainer or Fabric, they aren't easy to load or process externally. In Lightning 2.2, we introduced a CLI utility that consolidates the checkpoint folder into a single file. You can then load that file in raw PyTorch, for example with torch.load() (#19213).
Given a saved distributed checkpoint, you can convert it like so:
# For Trainer checkpoints:
python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint
# For Fabric checkpoints:
python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
Read more about distributed checkpointing in our documentation: Trainer, Fabric.
Improvements to Compiling DDP/FSDP in Fabric
PyTorch 2.0 introduced torch.compile, a powerful tool to speed up your models without changing the code.
We have now added a comprehensive guide on how to use torch.compile correctly, with tips and tricks to help you troubleshoot common issues. On top of that, Fabric.setup() will now reapply torch.compile on top of DDP/FSDP if you enable these strategies (#19280).
import torch
import lightning as L

# Select a distributed strategy (DDP, FSDP, ...)
fabric = L.Fabric(strategy="ddp", devices=8)
fabric.launch()

# Compile your model (a torch.nn.Module) before `.setup()`
model = torch.compile(model)

# Fabric now reapplies the compilation on top of the DDP/FSDP wrapper
model = fabric.setup(model)

# Alternatively, you can opt out if it is causing trouble
model = fabric.setup(model, _reapply_compile=False)
You might see fewer graph breaks, but don't expect significant speed-ups from this change. We introduced it mainly to prepare Fabric for future PyTorch improvements in optimizing distributed operations.
Saving and Loading DataLoader State
If you use a dataloader/iterable that implements the .state_dict() and .load_state_dict() interface, the Trainer will now automatically save and load its state in the checkpoint (#19361).
import lightning as L


class MyDataLoader:
    """A dataloader that implements the 'stateful' interface."""

    def state_dict(self):
        # Return a dictionary with state
        return {"batches_fetched": ...}

    def load_state_dict(self, state_dict):
        # Load the state from the checkpoint
        self.batches_fetched = state_dict["batches_fetched"]


model = ...
dataloader = MyDataLoader()
trainer = L.Trainer()

# Saves checkpoints that include the dataloader state
trainer.fit(model, dataloader)

# When you resume training, the dataloader can now load its state
trainer.fit(model, dataloader, ckpt_path="path/to/my/checkpoint")
Note that the standard PyTorch DataLoader does not support this stateful interface; this feature only works with loaders that implement these two methods. A dataloader that supports full fault tolerance will be included in our upcoming release of Lightning Data, a library to optimize data preprocessing and streaming in the cloud. Stay tuned!
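For illustration, here is a self-contained iterable that follows the state_dict()/load_state_dict() interface described above. The class name ResumableLoader and its skip-ahead resume strategy are assumptions made for this sketch. A real loader would also need to handle shuffling and multiple workers.

```python
class ResumableLoader:
    """Iterates over a list of batches and can resume where it left off (hypothetical)."""

    def __init__(self, batches):
        self.batches = list(batches)
        self.batches_fetched = 0

    def __iter__(self):
        # Start from the position restored by load_state_dict(), if any
        while self.batches_fetched < len(self.batches):
            batch = self.batches[self.batches_fetched]
            self.batches_fetched += 1
            yield batch
        # Reset so that the next epoch starts from the beginning
        self.batches_fetched = 0

    def state_dict(self):
        return {"batches_fetched": self.batches_fetched}

    def load_state_dict(self, state_dict):
        self.batches_fetched = state_dict["batches_fetched"]


loader = ResumableLoader(["b0", "b1", "b2", "b3"])
it = iter(loader)
seen = [next(it), next(it)]  # consume two batches
state = loader.state_dict()  # what would be stored in the checkpoint

resumed = ResumableLoader(["b0", "b1", "b2", "b3"])
resumed.load_state_dict(state)
rest = list(resumed)  # continues with the remaining batches
```

Because the position is only a counter, the saved state stays tiny. The catch is that resuming is only correct if the batch order is reproducible between runs.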
Non-strict Checkpoint Loading in Trainer
A feature that the community has requested for a long time is non-strict checkpoint loading. By default, a checkpoint in PyTorch is loaded with strict=True to ensure that all keys in the saved checkpoint match those in the model's state dict.
However, in some use cases it makes sense to exclude certain weights from the checkpoint. When resuming training, the user would then need to set strict=False, which wasn't configurable until now.
You can now set the attribute strict_loading=False on your LightningModule if you want to allow loading partial checkpoints (#19404).
import lightning as L


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()

        # This model only trains the decoder, we don't save the encoder
        self.encoder = from_pretrained(...).requires_grad_(False)
        self.decoder = Decoder()

        # Set to False because we only care about the decoder
        self.strict_loading = False

    def state_dict(self):
        # Don't save the encoder, it is not being trained
        return {k: v for k, v in super().state_dict().items() if "encoder" not in k}

    ...


trainer = L.Trainer()
model = LitModel()

# Will load weights with `.load_state_dict(strict=model.strict_loading)`
trainer.fit(model, ckpt_path="path/to/checkpoint")
Full documentation here.
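The strict flag comes from PyTorch's nn.Module.load_state_dict(). The following plain PyTorch sketch uses a hypothetical two-part model. It shows what happens when a checkpoint omits the encoder weights: strict=True raises an error, while strict=False loads what it can and reports the missing keys.

```python
import torch
from torch import nn


class EncoderDecoder(nn.Module):
    """Hypothetical stand-in for a model with a frozen encoder and a trained decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.decoder = nn.Linear(4, 2)


source = EncoderDecoder()
# Keep only the decoder weights, mimicking the state_dict() override in LitModel above
partial = {k: v for k, v in source.state_dict().items() if "encoder" not in k}

target = EncoderDecoder()
try:
    target.load_state_dict(partial)  # strict=True is the default
    strict_failed = False
except RuntimeError:
    strict_failed = True  # "Missing key(s) in state_dict: encoder.weight, ..."

# Non-strict loading returns the keys that could not be matched
result = target.load_state_dict(partial, strict=False)
decoder_matches = torch.equal(target.decoder.weight, source.decoder.weight)
```

Inspecting result.missing_keys is a cheap sanity check: it should list only the weights you deliberately left out.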
Notable Changes
The 2.0 series of Lightning releases guarantees core API stability: no name changes, argument renaming, hook removals etc. on core interfaces (Trainer, LightningModule, etc.) unless a feature is specifically marked experimental. Here we list a few behavioral changes, made only where the change significantly improves the user experience, improves performance, or fixes the correctness of a feature. These changes will likely not impact most users.
ModelCheckpoint's save-last Feature
In Lightning 2.1, we made the ModelCheckpoint(..., save_last=True) feature save a symbolic link to the last saved checkpoint instead of rewriting the checkpoint (#18748). This time saver is especially useful for large models that take a while to save. However, many users were confused by the new behavior and wanted it turned off, saving a copy instead of a symbolic link like before. In Lightning 2.2, we are reverting this decision and making the linking opt-in (#19191):
from lightning.pytorch.callbacks import ModelCheckpoint
# In 2.1 saves a symbolic link "last.ckpt" to the last checkpoint saved
# In 2.2 saves "last.ckpt" as a copy of the last checkpoint saved
checkpoint = ModelCheckpoint("./my_checkpoints", save_last=True)
# You can opt-in to save a symlink (if possible)
checkpoint = ModelCheckpoint("./my_checkpoints", save_last="link")
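To make the difference between the two modes concrete, here is a standalone sketch using only the Python standard library. The update_last() function is hypothetical and does not mirror ModelCheckpoint's internals. Its mode="copy" roughly corresponds to save_last=True and mode="link" to save_last="link". Where symbolic links cannot be created (for example, on some Windows setups), it falls back to a copy.

```python
import os
import shutil
import tempfile


def update_last(best_path: str, last_path: str, mode: str) -> None:
    """Make last_path refer to best_path, as a copy or as a symlink (hypothetical)."""
    if os.path.lexists(last_path):
        os.remove(last_path)
    if mode == "link":
        try:
            # A relative target keeps the link valid if the folder is moved
            os.symlink(os.path.basename(best_path), last_path)
            return
        except OSError:
            pass  # symlinks unsupported here; fall back to copying
    shutil.copyfile(best_path, last_path)


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "epoch=3.ckpt")
    with open(ckpt, "wb") as f:
        f.write(b"weights")
    last = os.path.join(tmp, "last.ckpt")

    update_last(ckpt, last, mode="copy")  # full copy: costs time and disk space
    copied_is_link = os.path.islink(last)

    update_last(ckpt, last, mode="link")  # near-instant, but only a pointer
    linked_is_link = os.path.islink(last)
    with open(last, "rb") as f:
        content = f.read()
```

The trade-off is the one described above. A link is cheap for large checkpoints, but it breaks if the target file is deleted or if the folder is copied by a tool that does not preserve symlinks.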
Removed Problematic Default Seeding
The seed_everything(x) utility function is useful to set the seed for several libraries like PyTorch, NumPy and Python in a single line of code. However, until now you were allowed to omit the seed value, in which case the function picked one at random. In certain cases, for example when processes are launched externally (e.g., SLURM, torchelastic etc.), this default behavior is dangerous because each process independently chooses its own random seed. This can affect sampling, randomized validation splits, and other behaviors that rely on each process having the same seed. In 2.2, we removed this behavior, and the seed now defaults to 0 (#18846):
from lightning.pytorch import seed_everything
# Set the random seed for PyTorch, NumPy, Python etc.
seed_everything(42)
# Not setting a value now defaults to 0
seed_everything()
In the unlikely event that you relied on the previous behavior, you now have to choose the seed randomly yourself:
import random
seed_everything(random.randint(0, 1000000))
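Here is why a shared seed matters. If each process draws its own validation split from its own random generator, the splits only agree when the generators were seeded identically. The sketch below simulates two processes in one interpreter with Python's random module. split_indices() is a hypothetical helper, not a Lightning function.

```python
import random


def split_indices(num_samples: int, val_fraction: float, seed: int) -> list:
    """Return a randomized set of validation indices, as one process would compute it."""
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return sorted(indices[: int(num_samples * val_fraction)])


# Both "processes" use the same seed (e.g. the new default of 0): identical splits
rank0 = split_indices(100, 0.2, seed=0)
rank1 = split_indices(100, 0.2, seed=0)

# Each "process" picks its own random seed (the removed behavior): splits usually differ,
# so samples used for validation on one rank may be used for training on another
rank0_random = split_indices(100, 0.2, seed=random.randint(0, 1_000_000))
rank1_random = split_indices(100, 0.2, seed=random.randint(0, 1_000_000))
```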
Miscellaneous Changes
- Dropped support for PyTorch 1.12 (#19300)
- The columns in the metrics.csv file produced by CSVLogger are now sorted alphabetically (#19159)
- Added support for meta-device initialization and materialization of 4-bit Bitsandbytes layers (#19150)
- Added TransformerEnginePrecision(fallback_compute_dtype=) to control the dtype of operations that don't support fp8 (#19082)
- We renamed the TransformerEnginePrecision(dtype=) argument to weights_dtype and made it required (#19082)
- The LightningModule.load_from_checkpoint() function now calls .configure_model() on the model if it is overridden, to ensure all layers can be loaded from the checkpoint (#19036)
CHANGELOG
PyTorch Lightning
Added
- Added lightning.pytorch.callbacks.ThroughputMonitor to track throughput and log it (#18848)
- The Trainer now restores the training mode set through .train() or .eval() on a submodule-level when switching from validation to training (#18951)
- Added support for meta-device initialization and materialization of 4-bit Bitsandbytes layers (#19150)
- Added TransformerEnginePrecision(fallback_compute_dtype=) to control the dtype of operations that don't support fp8 (#19082)
- Added the option ModelCheckpoint(save_last='link') to create a symbolic link for the 'last.ckpt' file (#19191)
- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file (#19213)
- The TQDM progress bar now respects the env variable TQDM_MINITERS for setting the refresh rate (#19381)
- Added support for saving and loading stateful training DataLoaders (#19361)
- Added shortcut name strategy='deepspeed_stage_1_offload' to the strategy registry (#19075)
- Added support for non-strict state-dict loading in Trainer via the new LightningModule.strict_loading = True | False attribute (#19404)
Changed
- seed_everything() without passing in a seed no longer randomly selects a seed, and now defaults to 0 (#18846)
- The LightningModule.on_{validation,test,predict}_model_{eval,train} hooks now only get called if they are overridden by the user (#18951)
- The Trainer.fit() loop no longer calls LightningModule.train() at the start; it now preserves the user's configuration of frozen layers (#18951)
- The LightningModule.load_from_checkpoint() function now calls .configure_model() on the model if it is overridden, to ensure all layers can be loaded from the checkpoint (#19036)
- Restored usage of the step parameter when logging metrics with NeptuneLogger (#19126)
- Changed the TransformerEnginePrecision(dtype=) argument to weights_dtype and made it required (#19082)
- The columns in the metrics.csv file produced by CSVLogger are now sorted alphabetically (#19159)
- Reverted back to creating a checkpoint copy when ModelCheckpoint(save_last=True) is set, instead of creating a symbolic link (#19191)
Deprecated
- Deprecated all precision plugin classes under lightning.pytorch.plugins with the suffix Plugin in the name (#18840)
Removed
- Removed support for PyTorch 1.12 (#19300)
Fixed
- Fixed issue where the precision="transformer-engine" argument would not replace layers by default (#19082)
- Fixed issue where layers created in LightningModule.setup or LightningModule.configure_model wouldn't get converted when using the Bitsandbytes or TransformerEngine plugins (#19061)
- Fixed the input validation logic in FSDPStrategy to accept a device_mesh (#19392)
Lightning Fabric
Added
- Added lightning.fabric.utilities.ThroughputMonitor and lightning.fabric.utilities.Throughput to track throughput and log it (#18848)
- Added lightning.fabric.utilities.AttributeDict for convenient dict-attribute access to represent state in script (#18943)
- Added support for meta-device initialization and materialization of 4-bit Bitsandbytes layers (#19150)
- Added TransformerEnginePrecision(fallback_compute_dtype=) to control the dtype of operations that don't support fp8 (#19082)
- Added support for clipping gradients by value with FSDP (#19236)
- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file (#19213)
- Added support for re-compiling the model inside Fabric.setup() over the FSDP/DDP wrappers (#19280)
Changed
- seed_everything() without passing in a seed no longer randomly selects a seed, and now defaults to 0 (#18846)
- Changed the TransformerEnginePrecision(dtype=) argument to weights_dtype and made it required (#19082)
- The columns in the metrics.csv file produced by CSVLogger are now sorted alphabetically (#19159)
Removed
- Removed support for PyTorch 1.12 (#19300)
Fixed
Full commit list: 2.1.0 -> 2.2.0
Contributors
Everyone who contributed between 2.1 and 2.2, in no particular order:
Veteran
@nik777 @Raalsky @wouterzwerink @AleksanderWWW @awaelchli @nohalon @ioangatop @Borda @ethanwharris @BoringDonut @mauvilsa @parambharat @tchaton @ryan597 @adamjstewart @rasbt @carmocca
New
@hiaoxui @VictorPrins @jaswon @AMHermansen @JalinWang @MF-FOOM @unacanal @Jamim @harishb00 @asingh9530 @dipta007 @daturkel @jerrymannil @mjbommar @shenmishajing @paganpasta @lauritsf @andyland @mathematicalmichael
Did you know?
Chuck Norris is a big fan and daily user of PyTorch Lightning.