Merged · Changes from 7 commits
2 changes: 2 additions & 0 deletions docs/source/en/internal/trainer_utils.md
@@ -32,6 +32,8 @@ Most of those are only useful if you are studying the code of the Trainer in the

[[autodoc]] torch_distributed_zero_first

[[autodoc]] load_pretrained_model_only_on_rank0

## Callbacks internals

[[autodoc]] trainer_callback.CallbackHandler
7 changes: 5 additions & 2 deletions src/transformers/__init__.py
@@ -3059,7 +3059,10 @@
_import_structure["sagemaker"] = []
_import_structure["time_series_utils"] = []
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
_import_structure["trainer_pt_utils"] = [
"load_pretrained_model_only_on_rank0",
"torch_distributed_zero_first",
]
_import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]

# TensorFlow-backed objects
@@ -6598,7 +6601,7 @@

# Trainer
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_pt_utils import load_pretrained_model_only_on_rank0, torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer

# TensorFlow
58 changes: 37 additions & 21 deletions src/transformers/modeling_utils.py
@@ -105,6 +105,14 @@
_init_weights = True


def is_fsdp_enabled():
    return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"


def is_fsdp_enabled_and_dist_rank_0():
return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0


if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -457,7 +465,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
return safe_load_file(checkpoint_file)
try:
if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled())
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
map_location = "meta"
else:
map_location = "cpu"
@@ -541,7 +553,7 @@ def load(module: nn.Module, state_dict, prefix=""):
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
module._load_from_state_dict(*args)

for name, child in module._modules.items():
@@ -1481,7 +1493,7 @@ def _get_resized_embeddings(
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
else:
elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

return new_embeddings
@@ -1565,7 +1577,7 @@ def _get_resized_lm_head(
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
else:
elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
@@ -2193,6 +2205,9 @@ def from_pretrained(
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)

if is_fsdp_enabled():
low_cpu_mem_usage = True

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -3265,23 +3280,24 @@ def _find_mismatched_keys(
)

if low_cpu_mem_usage:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

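Note on the `modeling_utils.py` changes above: with `ACCELERATE_USE_FSDP` set by the launcher, only global rank 0 materializes checkpoint weights on CPU, while every other rank maps the state dict to the `meta` device and skips the actual weight copy, so a large checkpoint occupies CPU RAM once per node instead of once per process. A minimal standalone sketch of the same map-location trick (not the Transformers implementation; the checkpoint path and an already-initialized process group are assumed):

```python
import os

import torch
import torch.distributed as dist


def load_checkpoint_ram_efficiently(checkpoint_file: str):
    """Load real weights on rank 0 only; other ranks get storage-free `meta` tensors."""
    fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    if fsdp_enabled and dist.is_initialized() and dist.get_rank() > 0:
        # "meta" keeps shapes and dtypes but allocates no memory for the values.
        map_location = "meta"
    else:
        map_location = "cpu"
    return torch.load(checkpoint_file, map_location=map_location)
```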
11 changes: 3 additions & 8 deletions src/transformers/trainer.py
@@ -461,10 +461,6 @@ def __init__(
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

self.forward_prefetch = False
if self.args.fsdp_config.get("forward_prefect", False):
self.forward_prefetch = True

self.limit_all_gathers = False
if self.args.fsdp_config.get("limit_all_gathers", False):
self.limit_all_gathers = True
@@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None):
auto_wrapper_callable = None
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
"transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
)

if self.args.fsdp_config["fsdp_min_num_params"] > 0:
if self.args.fsdp_config["min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
)
elif fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = set()
@@ -3825,7 +3821,6 @@ def create_accelerator_and_postprocess(self):
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)

if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
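For context on the renamed `fsdp_config` keys used in `_wrap_model`: `min_num_params` selects PyTorch's size-based auto-wrap policy, while `transformer_layer_cls_to_wrap` selects the transformer-block-based one. A rough sketch of how such policies are built with the stock FSDP wrap helpers (`T5Block` and the threshold are only illustrative, not something this PR prescribes):

```python
import functools

from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.t5.modeling_t5 import T5Block  # illustrative layer class

# fsdp_config = {"min_num_params": 100_000_000}: wrap any submodule above the threshold.
size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000_000)

# fsdp_config = {"transformer_layer_cls_to_wrap": ["T5Block"]}: wrap whole transformer blocks.
block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={T5Block})
```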
13 changes: 13 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1125,3 +1125,16 @@ def smp_nested_concat(tensor):
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
# which is also the name of the decorator so Python is confused.
return tensor.concat().detach().cpu()


def load_pretrained_model_only_on_rank0(model_cls, config_cls, model_name_or_path):
from accelerate.state import PartialState

state = PartialState()
if state.is_main_process:
model = model_cls.from_pretrained(model_name_or_path, return_dict=True)
else:
with torch.device("meta"):
config = config_cls.from_pretrained(model_name_or_path)
model = model_cls.from_config(config)
return model
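A hedged usage sketch for the new helper (the model class, checkpoint name, and config values are placeholders, and the script is assumed to run under `accelerate launch` with FSDP): rank 0 loads the real weights, the other ranks build a weightless shell on the `meta` device, and FSDP's `sync_module_states` broadcasts the rank-0 parameters once the model is wrapped.

```python
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    load_pretrained_model_only_on_rank0,
)

# Rank 0 downloads and loads the checkpoint; all other ranks get a meta-device shell.
model = load_pretrained_model_only_on_rank0(AutoModelForCausalLM, AutoConfig, "gpt2")

args = TrainingArguments(
    output_dir="outputs",
    fsdp="full_shard auto_wrap",
    # sync_module_states tells FSDP to broadcast the rank-0 weights to every rank.
    fsdp_config={"min_num_params": 100_000_000, "sync_module_states": "true"},
)
trainer = Trainer(model=model, args=args)  # pass datasets / data collator as usual
```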
59 changes: 38 additions & 21 deletions src/transformers/training_args.py
@@ -458,10 +458,17 @@ class TrainingArguments:
FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
forward pass.
- limit_all_gathers (`bool`, *optional*, defaults to `False`)
- fsdp_limit_all_gathers (`bool`, *optional*, defaults to `False`)
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- fsdp_use_orig_params (`bool`, *optional*, defaults to `False`)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters.
Useful in cases such as parameter-efficient fine-tuning.
Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
- fsdp_sync_module_states (`bool`, *optional*, defaults to `False`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0
to ensure they are the same across all ranks after initialization.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
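Taken together, the options documented above land in a single `fsdp_config` dictionary (or a JSON file passed via `--fsdp_config`). A sample covering the keys touched by this PR; the values are illustrative only and `T5Block` is just an example layer class:

```python
fsdp_config = {
    "transformer_layer_cls_to_wrap": ["T5Block"],  # mutually exclusive with "min_num_params" > 0
    "backward_prefetch": "backward_pre",
    "forward_prefetch": "false",
    "use_orig_params": "false",
    "sync_module_states": "true",  # broadcast rank-0 module states after initialization
}
```

Keys may also be written with the legacy `fsdp_` prefix; `__post_init__` (next hunk) strips it.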
@@ -1511,40 +1518,44 @@ def __post_init__(self):
self.fsdp_config = {}

if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
for k in list(self.fsdp_config.keys()):
    if k.startswith("fsdp_"):
        self.fsdp_config[k[len("fsdp_") :]] = self.fsdp_config.pop(k)

if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

self.fsdp_config["min_num_params"] = max(
self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params
)

# if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
]
# if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]

if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", []
self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]

if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")

if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

if (
len(self.fsdp) > 0
and self.fsdp_config["fsdp_min_num_params"] > 0
and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
and self.fsdp_config["min_num_params"] > 0
and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
@@ -1574,23 +1585,29 @@ def __post_init__(self):
FSDP_SHARDING_STRATEGY,
)

prefix = "FSDP_"
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "false")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")

if self.tpu_metrics_debug:
warnings.warn(
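A quick way to inspect what the environment-variable export in `__post_init__` produces (a local sketch; the exact `FSDP_SHARDING_STRATEGY` and `FSDP_AUTO_WRAP_POLICY` values come from accelerate's constants):

```python
import os

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    fsdp="full_shard auto_wrap",
    fsdp_config={"min_num_params": 100_000_000, "backward_prefetch": "backward_pre"},
)

# __post_init__ exports the settings as FSDP_* variables for accelerate's FSDP plugin.
for name in sorted(os.environ):
    if name.startswith("FSDP_"):
        print(f"{name}={os.environ[name]}")
```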
4 changes: 4 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -8493,6 +8493,10 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


def load_pretrained_model_only_on_rank0(*args, **kwargs):
requires_backends(load_pretrained_model_only_on_rank0, ["torch"])


def torch_distributed_zero_first(*args, **kwargs):
requires_backends(torch_distributed_zero_first, ["torch"])
