
Commit c4c0cef

add util for ram efficient loading of model when using fsdp (#25107)
* add util for ram efficient loading of model when using fsdp
* make fix-copies
* fixes 😅
* docs
* making it further easier to use
* rename the function
* refactor to handle fsdp ram efficiency in `from_pretrained`
* fixes
* fixes
* fixes
* update
* fixes
* revert `load_pretrained_model_only_on_rank0`
* resolve `load_from_checkpoint`
1 parent 4e1dee0 commit c4c0cef
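
The end-to-end intent of the change, as a minimal sketch (the script name, launcher config, and model checkpoint are placeholders, not part of this commit): under an `accelerate launch` run configured for FSDP, `from_pretrained` should keep only one full copy of the weights in host RAM, on rank 0, instead of one copy per process.

# train.py (launch with: accelerate launch --config_file fsdp.yaml train.py)
# Accelerate's FSDP launcher exports ACCELERATE_USE_FSDP=true on every rank.
# from_pretrained now detects that flag, forces low_cpu_mem_usage=True, loads
# the real weights only on rank 0, and leaves empty CPU tensors on the other
# ranks for FSDP to overwrite when it syncs module states.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")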

3 files changed (+105, -62 lines)

src/transformers/modeling_utils.py

Lines changed: 48 additions & 19 deletions
@@ -73,6 +73,7 @@
     is_torch_tpu_available,
     logging,
     replace_return_docstrings,
+    strtobool,
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
@@ -106,6 +107,14 @@
 _init_weights = True


+def is_fsdp_enabled():
+    return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+
+
+def is_fsdp_enabled_and_dist_rank_0():
+    return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
     from smdistributed.modelparallel import __version__ as SMP_VERSION
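
A standalone illustration of what the new helper checks (it mirrors, rather than imports, the library code): Accelerate's FSDP launcher exports `ACCELERATE_USE_FSDP` on every rank, and the helper simply reads that flag.

import os

_TRUTHY = {"1", "true", "t", "yes", "y", "on"}

def is_fsdp_enabled_sketch() -> bool:
    # Mirrors is_fsdp_enabled(): True only when the Accelerate FSDP launcher exported the flag.
    return os.environ.get("ACCELERATE_USE_FSDP", "False").lower() in _TRUTHY

print(is_fsdp_enabled_sketch())            # False in a plain `python` run
os.environ["ACCELERATE_USE_FSDP"] = "true"
print(is_fsdp_enabled_sketch())            # True, as under `accelerate launch` with an FSDP config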
@@ -458,7 +467,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         return safe_load_file(checkpoint_file)
     try:
-        if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
+        if (
+            (is_deepspeed_zero3_enabled() or is_fsdp_enabled)
+            and torch.distributed.is_initialized()
+            and torch.distributed.get_rank() > 0
+        ):
             map_location = "meta"
         else:
             map_location = "cpu"
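
Why "meta" helps here, shown in isolation: a meta tensor records shape and dtype but allocates no storage, so mapping the checkpoint there on ranks other than 0 keeps host RAM flat.

import torch

t = torch.empty(4096, 4096, device="meta")   # no memory is allocated for the data
print(t.shape, t.dtype, t.device)            # torch.Size([4096, 4096]) torch.float32 meta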
@@ -2283,6 +2296,9 @@ def from_pretrained(
         commit_hash = kwargs.pop("_commit_hash", None)
         variant = kwargs.pop("variant", None)

+        if is_fsdp_enabled():
+            low_cpu_mem_usage = True
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -3238,7 +3254,8 @@ def _fix_key(key):
             model_buffers = {".".join([prefix, key]) for key in model_buffers}
         unexpected_keys = list(unexpected_keys - model_buffers)

-        if device_map is None:
+        model.tie_weights()
+        if device_map is None and not is_fsdp_enabled():
             ptrs = collections.defaultdict(list)
             for name, tensor in model.state_dict().items():
                 id_tensor = id_tensor_storage(tensor)
@@ -3443,23 +3460,35 @@ def _find_mismatched_keys(
                 )

                 if low_cpu_mem_usage:
-                    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
-                        model_to_load,
-                        state_dict,
-                        loaded_keys,
-                        start_prefix,
-                        expected_keys,
-                        device_map=device_map,
-                        offload_folder=offload_folder,
-                        offload_index=offload_index,
-                        state_dict_folder=state_dict_folder,
-                        state_dict_index=state_dict_index,
-                        dtype=dtype,
-                        is_quantized=is_quantized,
-                        is_safetensors=is_safetensors,
-                        keep_in_fp32_modules=keep_in_fp32_modules,
-                    )
-                    error_msgs += new_error_msgs
+                    if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
+                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
+                            model_to_load,
+                            state_dict,
+                            loaded_keys,
+                            start_prefix,
+                            expected_keys,
+                            device_map=device_map,
+                            offload_folder=offload_folder,
+                            offload_index=offload_index,
+                            state_dict_folder=state_dict_folder,
+                            state_dict_index=state_dict_index,
+                            dtype=dtype,
+                            is_quantized=is_quantized,
+                            is_safetensors=is_safetensors,
+                            keep_in_fp32_modules=keep_in_fp32_modules,
+                        )
+                        error_msgs += new_error_msgs
+                    else:
+                        for key, param in model_to_load.state_dict().items():
+                            if param.device == torch.device("meta"):
+                                if not (is_quantized):
+                                    set_module_tensor_to_device(
+                                        model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
+                                else:
+                                    set_module_quantized_tensor_to_device(
+                                        model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
+                                    )
                 else:
                     error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
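
A condensed, hypothetical rendering of the branch above, for context only (not the library code; `build_model` and `checkpoint_path` are placeholders, and the process group is assumed to be initialized, e.g. under `torchrun`): rank 0 holds the real weights, the other ranks hold placeholders, and FSDP's `sync_module_states=True` broadcasts rank 0's values while wrapping.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def load_for_fsdp(build_model, checkpoint_path, dtype=torch.float32):
    model = build_model()
    if dist.get_rank() == 0:
        # Only rank 0 pays the host-RAM cost of the full checkpoint.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)
    else:
        # Other ranks keep placeholder values; only shapes and dtypes matter here.
        for param in model.parameters():
            param.data = torch.empty_like(param, dtype=dtype)
    # sync_module_states=True broadcasts rank 0's parameters during wrapping,
    # overwriting the placeholders before training starts.
    return FSDP(model, sync_module_states=True, device_id=torch.cuda.current_device())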

src/transformers/trainer.py

Lines changed: 10 additions & 10 deletions
@@ -465,10 +465,6 @@ def __init__(
             ):
                 self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

-            self.forward_prefetch = False
-            if self.args.fsdp_config.get("forward_prefetch", False):
-                self.forward_prefetch = True
-
             self.limit_all_gathers = False
             if self.args.fsdp_config.get("limit_all_gathers", False):
                 self.limit_all_gathers = True
@@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None):
             auto_wrapper_callable = None
             default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
             fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+                "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
             )

-            if self.args.fsdp_config["fsdp_min_num_params"] > 0:
+            if self.args.fsdp_config["min_num_params"] > 0:
                 auto_wrap_policy = functools.partial(
-                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
+                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
                 )
             elif fsdp_transformer_layer_cls_to_wrap is not None:
                 transformer_cls_to_wrap = set()
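
Roughly what the two wrapping branches above resolve to, as a standalone sketch using PyTorch's public FSDP wrap policies (the GPT-2 block class is just an example):

import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

# `min_num_params` > 0: wrap every submodule above a parameter-count threshold.
size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000)

# `transformer_layer_cls_to_wrap`: wrap the named transformer block classes instead.
cls_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block})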
@@ -1517,7 +1513,12 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

-        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
+        if (
+            resume_from_checkpoint is not None
+            and not is_sagemaker_mp_enabled()
+            and not self.is_deepspeed_enabled
+            and not self.is_fsdp_enabled
+        ):
             self._load_from_checkpoint(resume_from_checkpoint)

         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1651,7 +1652,7 @@ def _inner_training_loop(
         model = self._wrap_model(self.model_wrapped)

-        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+        if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
             self._load_from_checkpoint(resume_from_checkpoint, model)

         # as the model is wrapped, don't use `accelerator.prepare`
@@ -3886,7 +3887,6 @@ def create_accelerator_and_postprocess(self):
             fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                 "limit_all_gathers", fsdp_plugin.limit_all_gathers
             )
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)

         if self.is_deepspeed_enabled:
             if getattr(self.args, "hf_deepspeed_config", None) is None:

src/transformers/training_args.py

Lines changed: 47 additions & 33 deletions
@@ -436,13 +436,13 @@ class TrainingArguments:
             deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.

             A List of config and its options:
-                - fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+                - min_num_params (`int`, *optional*, defaults to `0`):
                     FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
                     passed).
-                - fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
+                - transformer_layer_cls_to_wrap (`List[str]`, *optional*):
                     List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
                     `T5Block` .... (useful only when `fsdp` flag is passed).
-                - fsdp_backward_prefetch (`str`, *optional*)
+                - backward_prefetch (`str`, *optional*)
                     FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
                     `fsdp` field is passed).
@@ -454,14 +454,22 @@
                 - `"backward_post"` : This prefetches the next set of parameters after the current set of
                   parameter’s
                   gradient computation.
-                - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
+                - forward_prefetch (`bool`, *optional*, defaults to `False`)
                     FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
                     If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
                     forward pass.
                 - limit_all_gathers (`bool`, *optional*, defaults to `False`)
                     FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
                     If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
                     all-gathers.
+                - use_orig_params (`bool`, *optional*, defaults to `False`)
+                    If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
+                    frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please
+                    refer this
+                    [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
+                - sync_module_states (`bool`, *optional*, defaults to `True`)
+                    If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
+                    ensure they are the same across all ranks after initialization
                 - xla (`bool`, *optional*, defaults to `False`):
                     Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
                     and its API may evolve in the future.
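
Putting the renamed options together, a hedged example of an `fsdp_config` using the new key names (the values and the GPT-2 class are illustrative; the same dict can also be saved as JSON and passed via `--fsdp_config`):

from transformers import TrainingArguments

fsdp_config = {
    "min_num_params": 0,
    "transformer_layer_cls_to_wrap": ["GPT2Block"],
    "backward_prefetch": "backward_pre",
    "forward_prefetch": "false",
    "sync_module_states": "true",
    "use_orig_params": "false",
    "limit_all_gathers": False,
    "xla": False,
}

args = TrainingArguments(output_dir="out", fsdp="full_shard auto_wrap", fsdp_config=fsdp_config)

When the config is loaded from a JSON file, keys that still carry the old `fsdp_` prefix are accepted as well: `__post_init__` strips the prefix, as sketched after the next hunk.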
@@ -1520,44 +1528,44 @@ def __post_init__(self):
             self.fsdp_config = {}

         if isinstance(self.fsdp_config, str):
+            if len(self.fsdp) == 0:
+                warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
+                for k, v in self.fsdp_config.items():
+                    if k.startswith("fsdp_"):
+                        self.fsdp_config[k.replace("fsdp_", "")] = v
+                        del self.fsdp_config[k]

         if self.fsdp_min_num_params > 0:
             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

-        self.fsdp_config["fsdp_min_num_params"] = max(
-            self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
-        )
+        self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)

-        # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
-        if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
-                self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
-            ]
+        # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
+        if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]

         if self.fsdp_transformer_layer_cls_to_wrap is not None:
             warnings.warn(
                 "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
             )
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", []
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", []
             ) + [self.fsdp_transformer_layer_cls_to_wrap]

-        if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
-            warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
+            warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")

-        if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-            warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+            warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

         if (
             len(self.fsdp) > 0
-            and self.fsdp_config["fsdp_min_num_params"] > 0
-            and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
+            and self.fsdp_config["min_num_params"] > 0
+            and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
         ):
-            raise ValueError(
-                "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
-            )
+            raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
         if self.fsdp_config["xla"]:
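
The key-renaming idea from the block above, in isolation (a standalone sketch, not the library code); iterating over a snapshot of the keys keeps the dict safe to modify while renaming:

legacy = {"fsdp_min_num_params": 2000, "fsdp_backward_prefetch": "backward_pre", "xla": False}
for key in list(legacy):
    if key.startswith("fsdp_"):
        legacy[key.removeprefix("fsdp_")] = legacy.pop(key)  # removeprefix needs Python 3.9+
print(legacy)  # {'xla': False, 'min_num_params': 2000, 'backward_prefetch': 'backward_pre'}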
@@ -1583,23 +1591,29 @@ def __post_init__(self):
                 FSDP_SHARDING_STRATEGY,
             )

+            prefix = "FSDP_"
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
                     # set environment variable for FSDP sharding strategy
-                    os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
+                    os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
+                        FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+                    )
                 elif fsdp_option == FSDPOption.OFFLOAD:
-                    os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
+                    os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
-                    os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
-                    if self.fsdp_config["fsdp_min_num_params"] > 0:
-                        os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
-                        os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
-                    elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-                        os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
-                            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
+                    os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
+                    if self.fsdp_config["min_num_params"] > 0:
+                        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
+                        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
+                    elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+                        os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
+                            self.fsdp_config["transformer_layer_cls_to_wrap"]
                         )
             prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
-            os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false")
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
+            os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")

         if self.tpu_metrics_debug:
             warnings.warn(
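
For orientation, a hedged sketch of the environment variables this block ends up exporting for `--fsdp "full_shard auto_wrap"` with a transformer layer class to wrap; Accelerate's FSDP plugin reads these `FSDP_*` variables when it builds the wrapper (exact values depend on the config, and the class name below is only an example):

import os

os.environ["FSDP_SHARDING_STRATEGY"] = "1"                      # index of FULL_SHARD, plus one
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"  # FSDP_AUTO_WRAP_POLICY[0]
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "GPT2Block"
os.environ["FSDP_BACKWARD_PREFETCH"] = "NO_PREFETCH"            # default when no prefetch policy is set
os.environ["FSDP_FORWARD_PREFETCH"] = "false"
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
os.environ["FSDP_USE_ORIG_PARAMS"] = "false"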

0 commit comments