Lightning 2.1: Train Bigger, Better, Faster
Lightning AI is excited to announce the release of Lightning 2.1 ⚡ It's the culmination of work from 79 contributors who have worked on features, bug fixes, and documentation for a total of over 750 commits since v2.0.
The theme of 2.1 is "bigger, better, faster": bigger because training large, multi-billion-parameter models has gotten even more efficient thanks to FSDP, efficient initialization, and sharded checkpointing improvements; better because it's easier than ever to scale models without making substantial code changes or installing third-party packages; and faster because it leverages the latest hardware features to speed up training in low-bit precision thanks to new precision plugins like bitsandbytes and Transformer Engine.
And of course, as the name implies, this release fully leverages the latest features in PyTorch 2.1 🎉
Highlights
Improvements To Large-Scale Training With FSDP
The FSDP strategy for training large billion-parameter models gets substantial improvements and new features in Lightning 2.1, both in Trainer and Fabric (in case you didn't know, Fabric is the latest addition to the Lightning family of tools to scale models without the boilerplate code).
FSDP is now more user-friendly to configure, has memory management and speed improvements, and we have a brand new end-to-end user guide with best practices (Trainer, Fabric).
Efficient Saving and Loading of Large Checkpoints
When training large billion-parameter models with FSDP, saving and resuming training, or even just loading model parameters for finetuning, can be challenging, as users are often plagued by out-of-memory errors and speed bottlenecks.
In 2.1, we made several improvements. Starting with saving checkpoints, we added support for distributed/sharded checkpoints, enabled through the state_dict_type setting in the strategy (#18364, #18358):
Trainer:
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
# Default used by the strategy
strategy = FSDPStrategy(state_dict_type="full")
# Enable saving distributed checkpoints
strategy = FSDPStrategy(state_dict_type="sharded")
trainer = L.Trainer(strategy=strategy, ...)
Fabric:
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
# Saving distributed checkpoints is the default
strategy = FSDPStrategy(state_dict_type="sharded")
# Save consolidated (single file) checkpoints
strategy = FSDPStrategy(state_dict_type="full")
fabric = L.Fabric(strategy=strategy, ...)
Distributed checkpoints are the fastest and most memory efficient way to save the state of very large models.
The distributed checkpoint format also makes it efficient to load these checkpoints back for resuming training in parallel, and it reduces the impact on CPU memory usage significantly. Furthermore, we've also introduced lazy-loading for non-distributed checkpoints (#18150, #18379), which greatly reduces the impact on CPU memory usage when loading a consolidated (single-file) checkpoint (e.g. for finetuning). Learn more about these features in our FSDP guides (Trainer, Fabric).
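To make this concrete, here is a minimal sketch of saving and resuming with sharded checkpoints in Fabric; the checkpoint path and state keys are illustrative placeholders, not taken from the docs:

import lightning as L
from lightning.fabric.strategies import FSDPStrategy

fabric = L.Fabric(strategy=FSDPStrategy(state_dict_type="sharded"), devices=8)
fabric.launch()

model, optimizer = ...  # set up via fabric.setup() and fabric.setup_optimizers()
state = {"model": model, "optimizer": optimizer, "step": 0}

# Each rank writes its own shard instead of gathering everything on rank 0
fabric.save("checkpoints/last", state)

# Resume later: the shards are loaded back in parallel into the same state
fabric.load("checkpoints/last", state)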
Fast and Memory-Optimized Initialization
A major challenge that users face when working with large models such as LLMs is dealing with the extreme memory requirements. Even something as simple as instantiating a model becomes non-trivial if the model is so large that it won't fit on a single GPU or even a single machine. In Lightning 2.1, we are introducing empty-weights initialization through the Fabric.init_module() (#17462, #17627) and Trainer.init_module()/LightningModule.configure_model() (#18004, #18385) methods:
Trainer:
import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Delay initialization of model to `configure_model()`

    def configure_model(self):
        # Model initialized in correct precision and weights on meta-device
        self.model = ...

    ...

model = MyModel()
trainer = L.Trainer(strategy="fsdp", ...)
trainer.fit(model)
Fabric:
import torch
import lightning as L

fabric = L.Fabric(strategy="fsdp", ...)

# Model initialized in correct precision and weights on meta-device
with fabric.init_module(empty_init=True):
    model = ...

# You can also initialize buffers and tensors directly on device and dtype
with fabric.init_tensor():
    model.mask.create()
    model.kv_cache.create()
    x = torch.randn(4, 128)

# Materialization and sharding of model happens inside here
model = fabric.setup(model)
Read more about this new feature and its other benefits in our docs (Trainer, Fabric).
User-Friendly Configuration
We made it super easy to configure the sharding and activation-checkpointing policies when you want to auto-wrap particular layers of your model for advanced control (#18045, #18084).
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
- from torch.distributed.fsdp.wrap import ModuleWrapPolicy
- strategy = FSDPStrategy(auto_wrap_policy=ModuleWrapPolicy({MyTransformerBlock}))
+ strategy = FSDPStrategy(auto_wrap_policy={MyTransformerBlock})
trainer = L.Trainer(strategy=strategy, ...)
Furthermore, the sharding strategy can now be conveniently set with a string value (#18087):
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
- from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
- strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
+ strategy = FSDPStrategy(sharding_strategy="SHARD_GRAD_OP")
trainer = L.Trainer(strategy=strategy, ...)
You no longer need to remember the long PyTorch imports! Fabric also supports all these improvements shown above.
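For example, here is a minimal sketch of the same simplified configuration with Fabric's FSDPStrategy; MyTransformerBlock stands in for your own layer class:

import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Auto-wrap layers by type and select the sharding strategy with a plain string
strategy = FSDPStrategy(
    auto_wrap_policy={MyTransformerBlock},
    sharding_strategy="SHARD_GRAD_OP",
)
fabric = L.Fabric(strategy=strategy, devices=8)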
True Half-Precision
Lightning now supports true half-precision for training and inference with all built-in strategies (#18193, #18217, #18213, #18219). With this setting, the memory required to store the model weights is only half of what is normally needed when running with float32. In addition, you get the same speed benefits as mixed precision training (precision="16-mixed"):
import lightning as L
# default
trainer = L.Trainer(precision="32-true")
# train with model weights in `torch.float16`
trainer = L.Trainer(precision="16-true")
# train with model weights in `torch.bfloat16`
# (if hardware supports it)
trainer = L.Trainer(precision="bf16-true")
The same settings are also available in Fabric! We recommend trying bfloat16 training (precision="bf16-true") as it is often more numerically stable than regular 16-bit precision (precision="16-true").
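As a small illustration, the Fabric equivalent of the Trainer settings above looks like this:

import lightning as L

# train with model weights in torch.bfloat16 (if hardware supports it)
fabric = L.Fabric(precision="bf16-true")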
Bitsandbytes Quantization
With the new Bitsandbytes precision plugin (#18655), you can now quantize your model for significant memory savings during training, finetuning, or inference with a selection of several state-of-the-art quantization algorithms (int8, fp4, nf4, and more). For the first time, Trainer and Fabric make bitsandbytes easy to use for general models.
Trainer:
import lightning as L
from lightning.pytorch.plugins import BitsandbytesPrecisionPlugin
# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecisionPlugin("nf4-dq")
trainer = L.Trainer(plugins=precision)
Fabric:
import lightning as L
from lightning.fabric.plugins import BitsandbytesPrecision
# this will pick out the compute dtype automatically, by default `bfloat16`
precision = BitsandbytesPrecision("nf4-dq")
fabric = L.Fabric(plugins=precision)
Transformer Engine
The Transformer Engine by NVIDIA is a library for accelerating transformer layers on the new Hopper (H100) generation of GPUs. With the integration in Lightning Trainer and Fabric (#17597, #18459), you get easy access to 8-bit mixed precision for significant speed-ups:
Trainer:
import lightning as L
# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
trainer = L.Trainer(precision="transformer-engine-float16")
Fabric:
import lightning as L
# Select 8-bit mixed precision via TransformerEngine, with model weights in float16
fabric = L.Fabric(precision="transformer-engine-float16")
More configuration options are available through the respective plugins in Trainer and Fabric.
Lightning on TPU Goes Brrr
Lightning 2.1 runs on the latest generation of TPU hardware on Google Cloud! TPU-v4 and TPU-v5 (#17227) are now fully supported in both Fabric and Trainer and run using the new PjRT runtime by default (#17352). PjRT is the runtime used by JAX and has shown an average improvement of 35% on benchmarks.
Trainer:
import lightning as L
trainer = L.Trainer(accelerator="tpu", devices=8)
model = MyModel()
trainer.fit(model) # uses PjRT if available
Fabric:
import lightning as L
def train(fabric):
    ...

fabric = L.Fabric(accelerator="tpu")
fabric.launch(train)  # uses PjRT if available
And what's even more exciting, you can now scale massive multi-billion parameter models on TPUs using FSDP (#17421).
import lightning as L
from lightning.fabric.strategies import XLAFSDPStrategy
strategy = XLAFSDPStrategy(
    # Most arguments from the PyTorch native FSDP strategy are also available here!
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    state_dict_type="full",
    sequential_save=True,
)

fabric = L.Fabric(devices=8, strategy=strategy)
fabric.launch(finetune)
You can find a full end-to-end finetuning example script in our Lit-GPT repository. The new XLA-FSDP strategy is experimental and currently only available in Fabric. Support in the Trainer will follow in the future.
Granular Control Over Checkpoints in Fabric
Several improvements for checkpoint saving and loading have landed in Fabric, enabling more fine-grained control over what is saved/loaded while reducing boilerplate code:
- There is a new `Fabric.load_raw()` method with which you can load model- or optimizer state-dicts saved externally by a non-Fabric application (e.g., raw PyTorch) (#18049):

import lightning as L

fabric = L.Fabric()
model = MyModel()

# A model weights file saved by your friend who doesn't use Fabric
fabric.load_raw("path/to/model.pt", model)

# Equivalent to this:
# model.load_state_dict(torch.load("path/to/model.pt"))
- A new parameter `Fabric.load(..., strict=True|False)` to disable strict loading (#17645):

import lightning as L

fabric = L.Fabric()
model = MyModel()
state = {"model": model}

# strict loading is the default
fabric.load("path/to/checkpoint.ckpt", state, strict=True)

# disable strict loading
fabric.load("path/to/checkpoint.ckpt", state, strict=False)
- A new parameter `Fabric.save(..., filter=...)` that enables you to exclude certain parameters of your model without writing boilerplate code for it (#17845):

import lightning as L

fabric = L.Fabric()
model, optimizer = ...
state = {"model": model, "optimizer": optimizer, "foo": 123}

# save only the weights that match a pattern
filter = {"model": lambda k, v: "weight" in k}
fabric.save("path/to/checkpoint.ckpt", state, filter=filter)
You can read more about the new options in our checkpoint guide.
Backward Incompatible Changes
The release of PyTorch Lightning 2.0 was a big step into a new chapter: it brought a more polished API and removed a lot of legacy code as well as outdated and experimental features, at the cost of a long list of breaking changes that made upgrading from 1.9 to 2.0 more work than usual. Moving forward, we promised to maintain full backward compatibility of our public core APIs to guarantee a smooth upgrade experience for everyone, and with 2.1 we are happy to deliver on this promise. A few exceptions were made in places where a change significantly improves the user experience, improves performance, or fixes the correctness of a feature. These changes will likely not impact most users.
PyTorch Lightning
TPU/XLA Changes
When selecting device indices via devices=[i], the Trainer now selects the i-th TPU core (0-based, previously it was 1-based) (#17227).
Before:
# Selects the first TPU core (1-based index)
trainer = Trainer(accelerator="tpu", devices=[1])
Now:
# Selects the second TPU core (0-based index)
trainer = Trainer(accelerator="tpu", devices=[1])
Multi-GPU in Jupyter Notebooks
Due to lack of reliability, Trainer now only runs on one GPU instead of all GPUs in a Jupyter notebook if devices="auto" (default) (#18291).
Before:
import lightning as L
# In Jupyter notebooks, this would select all available GPUs (DDP)
trainer = L.Trainer(accelerator="cuda", devices="auto")
Now:
# In Jupyter notebooks, this now selects only one GPU (the first)
trainer = L.Trainer(accelerator="cuda", devices="auto")
# You can still explicitly select multiple
trainer = L.Trainer(accelerator="cuda", devices=8)
Device Access in Setup Hook
- During `LightningModule.setup()`, the `self.device` now returns the device the module will be placed on instead of `cpu` (#18021)
Before:
def setup(self, stage):
# CPU regardless of the accelerator used
print(self.device)
Now:
def setup(self, stage):
# CPU/CUDA/MPS/XLA depending on accelerator
print(self.device)
Miscellaneous Changes
- `self.log`ed tensors are now kept in the original device to reduce unnecessary host-to-device synchronizations (#17334)
- The `FSDPStrategy` now loads checkpoints after the `configure_model`/`configure_sharded_model` hook (#18358)
- The `FSDPStrategy.load_optimizer_state_dict` and `FSDPStrategy.load_model_state_dict` are a no-op now (#18358)
- Removed experimental support for `torchdistx` due to a lack of project maintenance (#17995)
- Dropped support for PyTorch 1.11 (#18691)
Lightning Fabric
We thank the community for the amazing feedback we got for Fabric so far - keep it coming. The list of breaking changes is short and won't affect the vast majority of users.
Sharding Context Manager in Fabric.run()
We removed automatic sharding support with Fabric.run or using fabric.launch(fn). This only impacts FSDP and DeepSpeed strategy users who use this way of launching. Please note that Fabric.run is a legacy construct from the LightningLite days, and is not recommended today. Please instantiate your large FSDP or DeepSpeed model under the newly added fabric.init_module context manager (#17832).
Before:
import lightning as L
def train(fabric):
    # FSDP's `enable_wrap` context or `deepspeed.zero.Init()`
    # were applied automatically here
    model = LargeModel()
    ...
fabric = L.Fabric()
fabric.launch(train)
Now:
def train(fabric):
    # Use `init_module` explicitly to apply these context managers
    with fabric.init_module():
        model = LargeModel()
    ...
Multi-GPU in Jupyter Notebooks
Due to lack of reliability, Fabric now only runs on one GPU instead of all GPUs in a Jupyter notebook if devices="auto" (default) (#18291).
Before:
import lightning as L
# In Jupyter notebooks, this would select all available GPUs (DDP)
fabric = L.Fabric(accelerator="cuda", devices="auto")
Now:
# In Jupyter notebooks, this now selects only one GPU (the first)
fabric = L.Fabric(accelerator="cuda", devices="auto")
# You can still explicitly select multiple
fabric = L.Fabric(accelerator="cuda", devices=8)
CHANGELOG
PyTorch Lightning
Added
- Added `metrics_format` attribute to `RichProgressBarTheme` class (#18373)
- Added `CHECKPOINT_EQUALS_CHAR` attribute to `ModelCheckpoint` class (#17999)
- Added `**summarize_kwargs` to `ModelSummary` and `RichModelSummary` callbacks (#16788)
- Added support for the `max_size_cycle|max_size|min_size` iteration modes during evaluation (#17163)
- Added support for the TPU-v4 architecture (#17227)
- Added support for XLA's new PJRT runtime (#17352)
- Check for invalid TPU device inputs (#17227)
- Added `XLAStrategy(sync_module_states=bool)` to control whether to broadcast the parameters to all devices (#17522)
- Added support for multiple optimizer parameter groups when using the FSDP strategy (#17309)
- Enabled saving the full model state dict when using the `FSDPStrategy` (#16558)
- Update `LightningDataModule.from_datasets` to support arbitrary iterables (#17402)
- Run the DDP wrapper in a CUDA stream (#17334)
- Added `SaveConfigCallback.save_config` to ease use cases such as saving the config to a logger (#17475)
- Enabled optional file versioning of model checkpoints (#17320)
- Added the process group timeout argument `FSDPStrategy(timeout=...)` for the FSDP strategy (#17274)
- Added `FSDPStrategy(activation_checkpointing_policy=...)` to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) (#18045)
- Added CLI option `--map-to-cpu` to the checkpoint upgrade script to enable converting GPU checkpoints on a CPU-only machine (#17527)
- Added non-layer param count to the model summary (#17005)
- Updated `LearningRateMonitor` to log monitored values to `trainer.callback_metrics` (#17626)
- Added `log_weight_decay` argument to `LearningRateMonitor` callback (#18439)
- Added `Trainer.print()` to print on local rank zero only (#17980)
- Added `Trainer.init_module()` context manager to instantiate large models efficiently directly on device, dtype (#18004)
  - Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`) depending on the 'true' precision choice in `Trainer(precision='32-true'|'64-true')`
- Added the `LightningModule.configure_model()` hook to instantiate large models efficiently directly on device, dtype, and with sharding support (#18004)
  - Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding
- Added support for meta-device initialization with `Trainer.init_module(empty_init=True)` in FSDP (#18385)
- Added `lightning.pytorch.plugins.PrecisionPlugin.module_init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation (#18004)
- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA (#17882)
- Added a callback for spike-detection (#18014)
- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` (#18087)
- Improved error messages when attempting to load a DeepSpeed checkpoint at an invalid path (#17795)
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` (#18194)
- Added support for true half-precision training via `Trainer(precision="16-true"|"bf16-true")` (#18193, #18217, #18213, #18219)
- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised (#18218)
- Added validation of user input for `devices` and `num_nodes` when running with `SLURM` or `TorchElastic` (#18292)
- Added support for saving checkpoints with either full state-dict or sharded state dict via `FSDPStrategy(state_dict_type="full"|"sharded")` (#18364)
- Added support for loading sharded/distributed checkpoints in FSDP (#18358)
- Made the text delimiter in the rich progress bar configurable (#18372)
- Improved the error messaging and instructions when handling custom batch samplers in distributed settings (#18402)
- Added support for mixed 8-bit precision as `Trainer(precision="transformer-engine")` using Nvidia's Transformer Engine (#18459)
- Added support for linear layer quantization with `Trainer(plugins=BitsandbytesPrecision())` using bitsandbytes (#18655)
- Added support for passing the process group to the `FSDPStrategy` (#18583)
- Enabled the default process group configuration for FSDP's hybrid sharding (#18583)
- Added `lightning.pytorch.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings (#18591)
- Improved the `num_workers` warning to give a more accurate upper limit on the `num_workers` suggestion (#18591)
- Added `lightning.pytorch.utilities.is_shared_filesystem` utility function to automatically check whether the filesystem is shared between machines (#18586)
- Added support for returning an object of type `Mapping` from `LightningModule.training_step()` (#18657)
- Added the hook `LightningModule.on_validation_model_zero_grad()` to allow overriding the behavior of zeroing the gradients before entering the validation loop (#18710)
Changed
- Changed default metric formatting from `round(..., 3)` to `".3f"` format string in `MetricsTextColumn` class (#18483)
- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` (#17309)
- `Trainer(accelerator="tpu", devices=[i])` now selects the i-th TPU core (0-based, previously it was 1-based) (#17227)
- Allow using iterable-style datasets with TPUs (#17331)
- Increased the minimum XLA requirement to 1.13 (#17368)
- `self.log`ed tensors are now kept in the original device to reduce unnecessary host-to-device synchronizations (#17334)
- Made the run initialization in `WandbLogger` lazy to avoid creating artifacts when the CLI is used (#17573)
- Simplified redirection of `*_step` methods in strategies by removing the `_LightningModuleWrapperBase` wrapper module (#17531)
- Support kwargs input for LayerSummary (#17709)
- Dropped support for `wandb` versions older than 0.12.0 in `WandbLogger` (#17876)
- During `LightningModule.setup()`, the `self.device` now returns the device the module will be placed on instead of `cpu` (#18021)
- Increased the minimum supported `wandb` version for `WandbLogger` from 0.12.0 to 0.12.10 (#18171)
- The input tensors now get cast to the right precision type before transfer to the device (#18264)
- Improved the formatting of emitted warnings (#18288)
- Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device (#18275)
- The `FSDPStrategy` now loads checkpoints after the `configure_model`/`configure_sharded_model` hook (#18358)
- The `FSDPStrategy.load_optimizer_state_dict` and `FSDPStrategy.load_model_state_dict` are a no-op now (#18358)
- The `Trainer.num_val_batches`, `Trainer.num_test_batches` and `Trainer.num_sanity_val_batches` now return a list of sizes per dataloader instead of a single integer (#18441)
- The `*_step(dataloader_iter)` flavor now no longer takes the `batch_idx` in the signature (#18390)
- Calling `next(dataloader_iter)` now returns a triplet `(batch, batch_idx, dataloader_idx)` (#18390)
- Calling `next(combined_loader)` now returns a triplet `(batch, batch_idx, dataloader_idx)` (#18390)
- Due to lack of reliability, Trainer now only runs on one GPU instead of all GPUs in a Jupyter notebook if `devices="auto"` (default) (#18291)
- Made the `batch_idx` argument optional in `validation_step`, `test_step` and `predict_step` to maintain consistency with `training_step` (#18512)
- The `TQDMProgressBar` now consistently shows it/s for the speed even when the iteration time becomes larger than one second (#18593)
- The `LightningDataModule.load_from_checkpoint` and `LightningModule.load_from_checkpoint` methods now raise an error if they are called on an instance instead of the class (#18432)
- Enabled launching via `torchrun` in a SLURM environment; the `TorchElasticEnvironment` now gets chosen over the `SLURMEnvironment` if both are detected (#18618)
- If not set by the user, Lightning will set `OMP_NUM_THREADS` to `num_cpus / num_processes` when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks (#18677)
- The `ModelCheckpoint` no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder (#18750)
- The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` (#18750)
- Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses through `multiprocessing` (ddp_spawn, xla) (#18776)
- The `ModelCheckpoint` now saves a symbolic link if `save_last=True` and `save_top_k != 0` (#18748)
Deprecated
- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) (#17383)
- Deprecated the `TPUAccelerator` in favor of `XLAAccelerator` (#17383)
- Deprecated the `TPUPrecisionPlugin` in favor of `XLAPrecisionPlugin` (#17383)
- Deprecated the `TPUBf16PrecisionPlugin` in favor of `XLABf16PrecisionPlugin` (#17383)
- Deprecated the `Strategy.post_training_step` method (#17531)
- Deprecated the `LightningModule.configure_sharded_model` hook in favor of `LightningModule.configure_model` (#18004)
- Deprecated the `LightningDoublePrecisionModule` wrapper in favor of calling `Trainer.precision_plugin.convert_input()` (#18209)
Removed
Fixed
- Fixed an issue with reusing the same model across multiple trainer stages when using the `DeepSpeedStrategy` (#17531)
- Fixed the saving and loading of FSDP optimizer states (#17819)
- Fixed FSDP re-applying activation checkpointing when the user had manually applied it already (#18006)
- Fixed issue where unexpected exceptions would leave the default torch dtype modified when using true precision settings (#18500)
- Fixed issue where not including the `batch_idx` argument in the `training_step` would disable gradient accumulation (#18619)
- Fixed the replacement of callbacks returned in `LightningModule.configure_callbacks` when the callback was a subclass of an existing Trainer callback (#18508)
- Fixed `Trainer.log_dir` not returning the correct directory for the `CSVLogger` (#18548)
- Fixed redundant input-type casting in FSDP precision (#18630)
- Fixed numerical issues when reducing values in low precision with `self.log` (#18686)
- Fixed an issue that would cause the gradients to be erased if validation happened in the middle of a gradient accumulation phase (#18710)
- Fixed redundant file writes in `CSVLogger` (#18567)
- Fixed an issue that could lead to checkpoint files being deleted accidentally when resuming training (#18750)
Lightning Fabric
Added
- Added support for the TPU-v4 architecture (#17227)
- Added support for XLA's new PJRT runtime (#17352)
- Added support for Fully Sharded Data Parallel (FSDP) training with XLA (#18126, #18424, #18430)
- Check for invalid TPU device inputs (#17227)
- Added `XLAStrategy(sync_module_states=bool)` to control whether to broadcast the parameters to all devices (#17522)
- Added support for joint setup of model and optimizer with FSDP (#17305)
- Added support for handling multiple parameter groups in optimizers set up with FSDP (#17305)
- Added support for saving and loading sharded model and optimizer state with `FSDPStrategy` (#17323)
- Added a warning when calling methods on `_FabricModule` that bypass the strategy-specific wrappers (#17424)
- Added `Fabric.init_tensor()` context manager to instantiate tensors efficiently directly on device and dtype (#17488)
- Added `Fabric.init_module()` context manager to instantiate large models efficiently directly on device, dtype, and with sharding support (#17462)
  - Creates the model parameters in the desired dtype (`torch.float32`, `torch.float64`, `torch.float16`, or `torch.bfloat16`) depending on the 'true' precision choice in `Fabric(precision='32-true'|'64-true'|'16-true'|'bf16-true')`
  - Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding
- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading (#17627)
- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP (#18122)
- Added `lightning.fabric.plugins.Precision.module_init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation (#17462)
- Added `lightning.fabric.strategies.Strategy.tensor_init_context()` context manager to instantiate tensors efficiently directly on device and dtype (#17607)
- Run the DDP wrapper in a CUDA stream (#17334)
- Added support for true half-precision as `Fabric(precision="16-true"|"bf16-true")` (#17287)
- Added support for mixed 8-bit precision as `Fabric(precision="transformer-engine")` using Nvidia's Transformer Engine (#17597)
- Added support for linear layer quantization with `Fabric(plugins=BitsandbytesPrecision())` using bitsandbytes (#18655)
- Added error messaging for missed `.launch()` when it is required (#17570)
- Added support for saving checkpoints with either full state-dict or sharded state dict via `FSDPStrategy(state_dict_type="full"|"sharded")` (#17526)
- Added support for loading a full-state checkpoint file into a sharded model (#17623)
- Added support for calling hooks on a LightningModule via `Fabric.call` (#17874)
- Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state (#17645)
- Added the parameter `Fabric.save(..., filter=...)` to enable saving a partial checkpoint state (#17845)
- Added support for loading optimizer states from a full-state checkpoint file (#17747)
- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA (#17882)
- Automatically call `xla_model.mark_step()` after `optimizer.step()` with XLA (#17883)
- Added support for all half-precision modes in FSDP precision plugin (#17807)
- Added `FSDPStrategy(activation_checkpointing_policy=...)` to customize the layer policy for automatic activation checkpointing (requires torch>=2.1) (#18045)
- Added a callback for spike-detection (#18014)
- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` (#18087)
- Improved error messages when attempting to load a DeepSpeed checkpoint at an invalid path (#17795)
- Added `Fabric.load_raw()` for loading raw PyTorch state dict checkpoints for model or optimizer objects (#18049)
- Allowed accessing rank information in the main process before processes are launched when using the `XLAStrategy` (#18194)
- Added automatic process cleanup to avoid zombie child processes and stalls when exceptions are raised (#18218)
- Added validation of user input for `devices` and `num_nodes` when running with `SLURM` or `TorchElastic` (#18292)
- Improved the error messaging and instructions when handling custom batch samplers in distributed settings (#18402)
- Added support for saving and loading stateful objects other than modules and optimizers (#18513)
- Enabled the default process group configuration for FSDP's hybrid sharding (#18583)
- Added `lightning.fabric.utilities.suggested_max_num_workers` to assist with setting a good value in distributed settings (#18591)
- Added `lightning.fabric.utilities.is_shared_filesystem` utility function to automatically check whether the filesystem is shared between machines (#18586)
- Removed support for PyTorch 1.11 (#18691)
- Added support for passing the argument `.load_state_dict(..., assign=True|False)` on Fabric-wrapped modules in PyTorch 2.1 or newer (#18690)
Changed
- Allow using iterable-style datasets with TPUs (#17331)
- Increased the minimum XLA requirement to 1.13 (#17368)
- Fabric argument validation now only raises an error if conflicting settings are set through the CLI (#17679)
- DataLoader re-instantiation is now only performed when a distributed sampler is required (#18191)
- Improved the formatting of emitted warnings (#18288)
- Broadcast and reduction of tensors with XLA-based strategies now preserve the input's device (#18275)
- Due to lack of reliability, Fabric now only runs on one GPU instead of all GPUs in a Jupyter notebook if `devices="auto"` (default) (#18291)
- Enabled launching via `torchrun` in a SLURM environment; the `TorchElasticEnvironment` now gets chosen over the `SLURMEnvironment` if both are detected (#18618)
- If not set by the user, Lightning will set `OMP_NUM_THREADS` to `num_cpus / num_processes` when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks (#18677)
Deprecated
- Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition (#17381)
- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) (#17383)
- Deprecated the `TPUAccelerator` in favor of `XLAAccelerator` (#17383)
- Deprecated the `TPUPrecision` in favor of `XLAPrecision` (#17383)
- Deprecated the `TPUBf16Precision` in favor of `XLABf16Precision` (#17383)
Removed
- Removed automatic sharding support with `Fabric.run` or using `fabric.launch(fn)`. This only impacts FSDP and DeepSpeed strategy users. Please instantiate your module under the newly added `fabric.init_module` context manager (#17832)
- Removed the unsupported `checkpoint_io` argument from the `FSDPStrategy` (#18192)
Fixed
- Fixed issue where running on TPUs would select the wrong device index (#17227)
- Removed the need to call `.launch()` when using the DP-strategy (`strategy="dp"`) (#17931)
- Fixed FSDP re-applying activation checkpointing when the user had manually applied it already (#18006)
- Fixed FSDP re-wrapping the module root when the user had manually wrapped the model (#18054)
- Fixed issue where unexpected exceptions would leave the default torch dtype modified when using true precision settings (#18500)
- Fixed redundant input-type casting in FSDP precision (#18630)
- Fixed an issue with `find_usable_cuda_devices(0)` incorrectly returning a list of devices (#18722)
- Fixed redundant file writes in `CSVLogger` (#18567)
Lightning App
Added
- Allow customizing `gradio` components with lightning colors (#17054)
Changed
- Changed `LocalSourceCodeDir` cache_location to not use home in certain cases (#17491)
Removed
- Remove cluster commands from the CLI (#18151)
Full commit list: 2.0.0...2.1.0
Contributors
Veteran
@adamjstewart @akreuzer @ethanwharris @dmitsf @lantiga @nicolai86 @pl-ghost @carmocca @awaelchli @justusschock @edenlightning @belerico @lightningforever @nisheethlahoti @tchaton @yurijmikhalevich @mauvilsa @rlizzo @rusmux @yhl48 @Liyang90 @jerome-habana @JustinGoheen @Borda @speediedan @SkafteNicki @dcfidalgo
New
@saryazdi @parambharat @kshitij12345 @woqidaideshi @colehawkins @md-121 @gkroiz @idc9 @BoringDonut @OmerShubi @ishandutta0098 @ryan597 @leng-yue @alicanb @One-sixth @santurini @SpirinEgor @KogaiIrina @shanmugamr1992 @janeyx99 @asmith26 @dingusagar @AleksanderWWW @strawberrypie @solyaH @kaczmarj @voidful @water-vapor @bkiat1123 @rhiga2 @baskrahmer @felipewhitaker @mukhery @Quasar-Kim @robieta @one-matrix @jere357 @schmidt-ai @schuhschuh @anio @rjarun8 @callumhay @minhlong94 @klieret @giorgioskij @shihaoyin @JonathanRayner @NripeshN @marcimarc1 @bilelomrani1 @NikolasWolke @0x404 @quintenroets @Borodin @amorehead @SebastianGer @ioangatop @Tribhuvan0 @f0k @sameertantry @kwsp @nik777 @matsumotosan
Did you know?
When Chuck Norris trains a neural network, it not only learns, but it also gains the ability to defend itself from adversarial attacks by roundhouse kicking them into submission.