enable LoRA + FSDP2 #855
@@ -17,20 +17,15 @@
from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.distributed._composable.fsdp import fully_shard
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import (
    get_adapter_params,
    get_merged_lora_ckpt,
    set_trainable_params,
    validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
@@ -277,86 +272,62 @@ def _setup_model(
            the correct device.
        """

        if self._device.type != "cuda":
            raise ValueError(
                f'FSDP needs device="cuda" but found device={self._device.type}'
            )

        if self._is_rank_zero:
            log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
            log.info("FSDP is enabled. Model init and checkpoint loading on Rank 0 ...")
            init_start = time.perf_counter()

        with utils.set_default_dtype(self._dtype):
            model = config.instantiate(cfg_model)
        with utils.set_default_dtype(self._dtype), torch.device("meta"):
Comment: Sorry, not able to comment above, but the docstring of this function should be updated since we're no longer initializing on CPU?

Reply: The docstring used to be …

Reply: Oh, just got your point. Updated the docstring for …
            model = config.instantiate(cfg_model)

        log.info(
            f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
        )
        # Note: this needs to be set before wrapping with FSDP
        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

        # The model contains LoRA params which won't have any matching keys in
        # the state dict. As a result, we need to load with strict=False.
        # Before loading the state dict, ensure the state dict keys for the base
        # model and adapters (if available) match the keys in the full LoRA model
        # This is a good sanity check to prevent silent errors
        validate_state_dict_for_lora(
            lora_attn_modules=cfg_model.lora_attn_modules,
            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
            apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
            full_model_state_dict_keys=model.state_dict().keys(),
            lora_state_dict_keys=(
                lora_weights_state_dict.keys()
                if lora_weights_state_dict is not None
                else None
            ),
            base_model_state_dict_keys=base_model_state_dict.keys(),
        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )
        # Load both the base model weights and (if available) the adapter weights. Both
        # of this should happen only on Rank 0
        model.load_state_dict(base_model_state_dict, strict=False)
        if lora_weights_state_dict:
            model.load_state_dict(lora_weights_state_dict, strict=False)
        for m in model.modules():
            if isinstance(m, modules.TransformerDecoderLayer):
                fully_shard(m)
        fully_shard(model)
Comment: Sorry for the noob question, but can you help me understand what's going on here? Why do I need to … An unrelated question: if I have enough GPU memory, should I be thinking about using something similar to …

Reply: In FSDP1, we wrap each … In FSDP2, we un-blackboxed it into this for-loop. If you prefer, this can be factored into a util function in torchtune so users call … Personally, I have a bias towards the un-blackboxed approach since people can modify the for-loop to achieve different wrapping.

Reply: The equivalent of SHARD_GRAD_OP in FSDP2 is …

Comment: Thanks for the explanation! I love the un-blackboxed approach here - just needs more documentation and explanation :) After reading the FSDP2 RFC, this became a lot clearer.
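To make the wrapping pattern in this thread concrete, here is a minimal sketch of FSDP2's un-blackboxed wrapping. It is not the PR's exact code: the recipe matches `modules.TransformerDecoderLayer` inside `_setup_model`, while the `shard_model` helper and its `layer_cls` parameter are illustrative assumptions.

```python
# Minimal sketch, assuming torch.distributed is already initialized and the
# model is built on a meta or CUDA device before sharding.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def shard_model(
    model: nn.Module,
    layer_cls: type,
    reshard_after_forward: bool = True,
) -> nn.Module:
    # Shard each transformer block as its own FSDP unit so that only one
    # block's parameters are all-gathered at a time during forward/backward.
    for module in model.modules():
        if isinstance(module, layer_cls):
            fully_shard(module, reshard_after_forward=reshard_after_forward)
    # Shard whatever parameters remain at the root (e.g. embeddings, output).
    fully_shard(model, reshard_after_forward=reshard_after_forward)
    return model
```

Passing `reshard_after_forward=False` keeps the gathered parameters alive between forward and backward, trading extra memory for fewer all-gathers; this is FSDP2's counterpart to FSDP1's `SHARD_GRAD_OP` strategy.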
        else:
            # For non-zero ranks, load the model on meta device
            with utils.set_default_dtype(self._dtype), torch.device("meta"):
                model = config.instantiate(cfg_model)
        utils.load_from_full_state_dict(
            model, base_model_state_dict, self._device, self._is_rank_zero
        )
        if lora_weights_state_dict:
            utils.load_from_full_state_dict(
                model, lora_weights_state_dict, self._device, self._is_rank_zero
            )
Comment: Pros and cons of meta init: the pro is a 4.5x speedup during model init, and thus a shorter TTFB; the con is that the user needs to call …

Reply: Is this because these params are not being loaded from checkpoint? Or do I misunderstand? If this is indeed the reason, how do we handle this code block when the LoRA params are being loaded from checkpoint (e.g. when resuming training)?

Reply: You are right. When finetuning from an original HF checkpoint, … ; for resuming training, …

Reply: Got you, thanks so much for the explanation! I think something that would be super helpful would be to document here, in the form of comments, the relationship between: …

Comment: Also, I think there was a technical reason with FSDP1 to call the function …

Reply: Good point! Will add a comment to explain. FSDP1 calls … It's just that FSDP1 has a contract to call the overridden …
        if self._dtype == torch.bfloat16:
            model = model.to(torch.bfloat16)
        with utils.set_default_dtype(self._dtype), self._device:
            for m in model.modules():
                if isinstance(m, LoRALinear) and not lora_weights_state_dict:
                    # to_empty is needed since kaiming_uniform_ is inplace
                    m.to_empty(device=self._device)
                    m.initialize_parameters()
                if isinstance(m, modules.RotaryPositionalEmbeddings):
Comment: Just to clarify, we special-case RoPE because the buffer is not being loaded from a state dict, right?

Reply: That's correct.

Comment: Similar comment here: let's document what's happening so that users can easily understand why we initialize these modules separately.
                    m.reset_parameters()

        model = model.to(self._dtype)
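Related to the meta-init discussion above: the removed FSDP1 path further down kept `utils.validate_no_params_on_meta_device(model)` as a safety net. A comparable check for the FSDP2 path might look like the sketch below; the helper name is hypothetical and not part of the PR.

```python
import torch.nn as nn


def assert_no_meta_tensors(model: nn.Module) -> None:
    """Raise if any parameter or buffer was neither loaded from the state dict
    nor explicitly materialized after meta-device initialization."""
    for name, tensor in list(model.named_parameters()) + list(model.named_buffers()):
        if tensor.is_meta:
            raise RuntimeError(
                f"{name} is still on the meta device; it was neither loaded "
                "from the checkpoint nor initialized above"
            )
```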
        # LoRA hyper-params needed for merging weights while saving checkpoints
        self._lora_rank = cfg_model.lora_rank
        self._lora_alpha = cfg_model.lora_alpha

        # Note: this needs to be set before wrapping with FSDP
        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

        model = FSDP(
            module=model,
            auto_wrap_policy=utils.lora_fsdp_wrap_policy(
                modules_to_wrap={modules.TransformerDecoderLayer}
            ),
            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
            device_id=self._device,
            # this recipe does not currently support mixed precision training
            mixed_precision=None,
            # Ensure we broadcast params and buffers from rank 0
            sync_module_states=True,
            # Initialize empty modules on all non-zero ranks
            param_init_fn=(
                lambda module: module.to_empty(
                    device=torch.device("cuda"), recurse=False
                )
                if not self._is_rank_zero
                else None
            ),
        )

        # Ensure no params and buffers are on meta device
        utils.validate_no_params_on_meta_device(model)

        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )
        if self._is_rank_zero:
            log.info(
                f"Model init and checkpoint loading took {time.perf_counter() - init_start:.2f} secs"
            )
            memory_stats = utils.get_memory_stats(device=self._device)
            utils.log_memory_stats(memory_stats)
@@ -451,22 +422,22 @@ def save_checkpoint(
        intermediate_checkpoint = epoch + 1 < self.total_epochs
        # To prevent GPU memory from spiking during checkpoint save,
        # we consolidate the full model and optim state dicts on CPU for rank 0
        with FSDP.state_dict_type(
            self._model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            cpu_state_dict = self._model.state_dict()
            if intermediate_checkpoint:
                opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)
        sharded_sd = self._model.state_dict()
        cpu_state_dict = {}
        for param_name, sharded_param in sharded_sd.items():
            full_param = sharded_param.full_tensor()
            if self._is_rank_zero:
                cpu_state_dict[param_name] = full_param.cpu()
            else:
                opt_state_dict = None
                del full_param

        # TODO: implement optim state dict
        opt_state_dict = None

        # Now that we have the model and opt state dict, create the actual checkpoint dict
        # to be sent to the checkpointer and ultimately written to file
        if self._is_rank_zero:

            # Filter out the adapter keys and weights from the model state dict. These will
            # be saved separately
            adapter_key_filter = lambda x: x in self.adapter_params
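Gathering the optimizer state is left as a TODO in this diff. One possible approach, sketched here as an assumption rather than the PR's implementation, is to walk the sharded optimizer state and gather each DTensor the same way the model state dict is gathered above:

```python
import torch
from torch.distributed._tensor import DTensor


def gather_full_optim_state(optimizer: torch.optim.Optimizer, is_rank_zero: bool):
    full_state = {}
    for param_id, param_state in optimizer.state_dict()["state"].items():
        full_state[param_id] = {}
        for key, value in param_state.items():
            if isinstance(value, DTensor):
                # full_tensor() is a collective: every rank must call it.
                value = value.full_tensor()
            if is_rank_zero:
                if isinstance(value, torch.Tensor):
                    value = value.cpu()
                full_state[param_id][key] = value
            else:
                # Non-zero ranks drop the gathered tensor right away to cap memory.
                del value
    return full_state if is_rank_zero else None
```

As with the model weights above, only rank 0 keeps the materialized copies on CPU; the other ranks only participate in the collectives.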
The second file in the diff is torchtune's `utils` module, which gains the new `load_from_full_state_dict` helper used by the recipe above:
@@ -8,11 +8,13 @@
import logging
import os
from itertools import chain
from typing import Callable, Dict, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union

import torch
import torch.distributed as dist
import torch.distributed._composable.fsdp
from torch import nn
from torch.distributed._tensor import distribute_tensor
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
@@ -297,3 +299,31 @@ def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs):
        return isinstance(module, tuple(modules_to_wrap))

    return lora_wrap_fsdp
def load_from_full_state_dict(
    model: torch.distributed._composable.fsdp.FSDP,
    full_sd: Dict[str, Any],
    device: torch.device,
    is_rank_zero: bool,
):
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        sharded_meta_param = meta_sharded_sd.get(param_name)
        mesh = sharded_meta_param.device_mesh
        if is_rank_zero:
            full_tensor = full_tensor.detach().to(device)
            torch.distributed.broadcast(full_tensor, src=0, group=mesh.get_group(0))
        else:
            full_tensor = torch.empty(
                sharded_meta_param.size(),
                device=device,
                dtype=sharded_meta_param.dtype,
            )
            torch.distributed.broadcast(full_tensor, src=0, group=mesh.get_group(0))
        sharded_tensor = distribute_tensor(
            full_tensor, mesh, sharded_meta_param.placements
        )
        sharded_sd[param_name] = nn.Parameter(sharded_tensor)
    model.load_state_dict(sharded_sd, strict=False, assign=True)
Comment: If we catch missing and unexpected keys from …

Reply: For FSDP2, it's clean FQNs without the FSDP prefix. For example, … FSDP2 is clean because 1) …

Comment: This is great! I think this means we can actually do validation of the LoRA state dict load more cleanly (note that we actually have two separate utilities for this for the single-device vs distributed case because of this FSDP prefix issue). Not a concern for this PR, but this will allow us to clean up our code a bit.
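As a concrete illustration of the cleanup discussed here, a validation helper could consume the missing/unexpected keys returned by `model.load_state_dict(sharded_sd, strict=False, assign=True)` directly, since FSDP2 reports plain FQNs. The sketch below is an assumption, not part of the PR, and it presumes torchtune's `LoRALinear` registers its adapter weights under `lora_a` / `lora_b`:

```python
from typing import Iterable, List


def check_lora_load_result(missing: Iterable[str], unexpected: Iterable[str]) -> None:
    # With FSDP2 there is no "_fsdp_wrapped_module." prefix to strip, so the
    # reported keys can be compared against plain module names directly.
    if unexpected:
        raise RuntimeError(f"Unexpected keys in checkpoint: {list(unexpected)}")
    # When loading a base-model-only checkpoint, only the LoRA adapter params
    # are allowed to be missing.
    non_lora_missing: List[str] = [
        k for k in missing if "lora_a" not in k and "lora_b" not in k
    ]
    if non_lora_missing:
        raise RuntimeError(f"Missing non-LoRA keys: {non_lora_missing}")
```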