diff --git a/examples/language-modeling/fsdp_config.json b/examples/language-modeling/fsdp_config.json new file mode 100644 index 0000000000..4aae21af2f --- /dev/null +++ b/examples/language-modeling/fsdp_config.json @@ -0,0 +1,12 @@ +{ + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_backward_prefetch": "BACKWARD_PRE", + "fsdp_forward_prefetch": false, + "fsdp_offload_params": false, + "fsdp_sharding_strategy": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_sync_module_states": true, + "fsdp_use_orig_params": true, + "transformer_layer_cls_to_wrap": "GaudiLlamaDecoderLayer", + "fsdp_activation_checkpointing": false +} diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index b1bf69edcd..b480990752 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -30,11 +30,8 @@ import torch import transformers from datasets import load_dataset -from peft import ( - LoraConfig, - TaskType, - get_peft_model, -) +from peft import LoraConfig, TaskType, get_peft_model, tuners +from peft.utils.other import fsdp_auto_wrap_policy from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -45,6 +42,7 @@ from transformers.trainer_utils import is_main_process from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments +from optimum.habana.peft.layer import GaudiLoraLayerLinearForward from optimum.habana.utils import set_seed @@ -674,6 +672,7 @@ def compute_metrics(eval_preds): ) if training_args.gradient_checkpointing: model.enable_input_require_grads() + tuners.lora.layer.Linear.forward = GaudiLoraLayerLinearForward lora_model = get_peft_model(model, peft_config) if training_args.bf16: lora_model = lora_model.to(torch.bfloat16) @@ -695,6 +694,10 @@ def compute_metrics(eval_preds): preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None, ) + # Solution for https://github.com/huggingface/peft/blob/v0.6.2/README.md#caveats (1) + if training_args.fsdp and training_args.fsdp_config["auto_wrap_policy"] == "TRANSFORMER_BASED_WRAP": + trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(lora_model) + if training_args.do_train: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) trainer.save_model() diff --git a/examples/question-answering/fsdp_config.json b/examples/question-answering/fsdp_config.json new file mode 100644 index 0000000000..27e9aeaf77 --- /dev/null +++ b/examples/question-answering/fsdp_config.json @@ -0,0 +1,12 @@ +{ + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_backward_prefetch": "BACKWARD_PRE", + "fsdp_forward_prefetch": false, + "fsdp_offload_params": false, + "fsdp_sharding_strategy": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_sync_module_states": true, + "fsdp_use_orig_params": true, + "transformer_layer_cls_to_wrap": "BertLayer", + "fsdp_activation_checkpointing": false +} diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index b47b75a3c4..56e9a63e23 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -16,6 +16,7 @@ from __future__ import annotations import contextlib +import functools import math import os import sys @@ -37,7 +38,6 @@ DistributedDataParallelKwargs, DistributedType, FP8RecipeKwargs, - FullyShardedDataParallelPlugin, GradientAccumulationPlugin, GradScalerKwargs, InitProcessGroupKwargs, @@ -50,8 +50,10 @@ check_os_kernel, convert_outputs_to_fp32, is_deepspeed_available, + is_torch_version, parse_choice_from_env, ) +from accelerate.utils.constants import FSDP_PYTORCH_VERSION from accelerate.utils.operations import _gpu_gather from accelerate.utils.other import is_compiled_module from torch.optim.lr_scheduler import LRScheduler @@ -68,7 +70,12 @@ from .data_loader import gaudi_prepare_data_loader from .state import GaudiAcceleratorState, GaudiPartialState -from .utils import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin +from .utils import ( + GaudiDistributedType, + GaudiDynamoBackend, + GaudiFullyShardedDataParallelPlugin, + GaudiTorchDynamoPlugin, +) logger = get_logger(__name__) @@ -87,7 +94,7 @@ def __init__( gradient_accumulation_steps: int = 1, cpu: bool = False, deepspeed_plugin: DeepSpeedPlugin | None = None, - fsdp_plugin: FullyShardedDataParallelPlugin | None = None, + fsdp_plugin: GaudiFullyShardedDataParallelPlugin | None = None, megatron_lm_plugin: MegatronLMPlugin | None = None, rng_types: list[str | RNGType] | None = None, log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None, @@ -142,6 +149,27 @@ def __init__( deepspeed_plugin.set_mixed_precision(mixed_precision) deepspeed_plugin.set_deepspeed_weakref() + if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( + fsdp_plugin, GaudiFullyShardedDataParallelPlugin + ): + import importlib.metadata + + torch_version = importlib.metadata.version("torch") + torch_version = torch_version[5:] + if is_torch_version("<", FSDP_PYTORCH_VERSION + torch_version): + raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") + + if fsdp_plugin is None: # init from env variables + fsdp_plugin = ( + GaudiFullyShardedDataParallelPlugin() + if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" + else None + ) + else: + if not isinstance(fsdp_plugin, GaudiFullyShardedDataParallelPlugin): + raise TypeError("`fsdp_plugin` must be a GaudiFullyShardedDataParallelPlugin object.") + os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided + # Kwargs handlers self.ddp_handler = None self.scaler_handler = None @@ -370,6 +398,54 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e if any(p.requires_grad for p in model.parameters()): kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) + elif self.distributed_type == GaudiDistributedType.FSDP: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, + # don't wrap it again + # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it + # is a FSDP model, don't wrap it again + is_type_fsdp = isinstance(model, FSDP) or ( + is_compiled_module(model) and isinstance(model._orig_mod, FSDP) + ) + + if not is_type_fsdp: + self.state.fsdp_plugin.set_auto_wrap_policy(model) + fsdp_plugin = self.state.fsdp_plugin + kwargs = { + "sharding_strategy": fsdp_plugin.sharding_strategy, + "cpu_offload": fsdp_plugin.cpu_offload, + "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, + "mixed_precision": fsdp_plugin.mixed_precision_policy, + "sync_module_states": fsdp_plugin.sync_module_states, + "backward_prefetch": fsdp_plugin.backward_prefetch, + "forward_prefetch": fsdp_plugin.forward_prefetch, + "use_orig_params": fsdp_plugin.use_orig_params, + "param_init_fn": fsdp_plugin.param_init_fn, + "ignored_modules": fsdp_plugin.ignored_modules, + "limit_all_gathers": fsdp_plugin.limit_all_gathers, + "device_id": torch.device("hpu"), + } + model = FSDP(model, **kwargs) + if fsdp_plugin.activation_checkpointing: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, + ) + + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ), + auto_wrap_policy=fsdp_plugin.auto_wrap_policy, + ) + # if the previous and current models are same, delete the previous one + if len(self._models) > 1 and (self._models[-2] is self._models[-1]): + del self._models[-2] + self._models[-1] = model # torch.compile should be called last and only if the model isn't already compiled. if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model): model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) @@ -672,7 +748,11 @@ def gather(self, tensor): tensor([0, 1, 2, 3]) ``` """ - if GaudiPartialState().distributed_type in [GaudiDistributedType.MULTI_HPU, GaudiDistributedType.DEEPSPEED]: + if GaudiPartialState().distributed_type in [ + GaudiDistributedType.MULTI_HPU, + GaudiDistributedType.DEEPSPEED, + GaudiDistributedType.FSDP, + ]: return _gpu_gather(tensor) else: return tensor @@ -719,6 +799,14 @@ def get_state_dict(self, model, unwrap=True): from deepspeed.checkpoint.utils import clone_tensors_for_torch_save state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict()) + # copied from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/accelerator.py#L3057 + elif self.distributed_type == DistributedType.FSDP: + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config): + state_dict = model.state_dict() else: if unwrap: model = self.unwrap_model(model) diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index 7ccbdbf593..5c81e06a52 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -67,6 +67,11 @@ def __init__(self, cpu: bool = False, **kwargs): deepspeed.init_distributed(dist_backend=self.backend, **kwargs) logger.info("DeepSpeed is enabled.") self._mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config + elif os.environ.get("ACCELERATE_USE_FSDP", "false") == "true": + self.distributed_type = GaudiDistributedType.FSDP + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=self.backend, rank=rank, world_size=world_size) + logger.info("Enabled distributed run.") else: self.distributed_type = GaudiDistributedType.MULTI_HPU if not torch.distributed.is_initialized(): @@ -115,6 +120,7 @@ def wait_for_everyone(self): GaudiDistributedType.MULTI_CPU, GaudiDistributedType.DEEPSPEED, GaudiDistributedType.MULTI_HPU, + GaudiDistributedType.FSDP, ): torch.distributed.barrier() @@ -171,6 +177,10 @@ def __init__( ) if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu: self.deepspeed_plugin = deepspeed_plugin + if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" and not cpu: + if self._mixed_precision != "no": + fsdp_plugin.set_mixed_precision(self._mixed_precision) + self.fsdp_plugin = fsdp_plugin GaudiPartialState._shared_state["distributed_type"] = self.distributed_type self.use_ipex = False diff --git a/optimum/habana/accelerate/utils/__init__.py b/optimum/habana/accelerate/utils/__init__.py index 5e5de8a194..6dd629291d 100644 --- a/optimum/habana/accelerate/utils/__init__.py +++ b/optimum/habana/accelerate/utils/__init__.py @@ -1 +1,6 @@ -from .dataclasses import GaudiDistributedType, GaudiDynamoBackend, GaudiTorchDynamoPlugin +from .dataclasses import ( + GaudiDistributedType, + GaudiDynamoBackend, + GaudiFullyShardedDataParallelPlugin, + GaudiTorchDynamoPlugin, +) diff --git a/optimum/habana/accelerate/utils/dataclasses.py b/optimum/habana/accelerate/utils/dataclasses.py index 18346477c1..07e256372f 100644 --- a/optimum/habana/accelerate/utils/dataclasses.py +++ b/optimum/habana/accelerate/utils/dataclasses.py @@ -17,6 +17,9 @@ from dataclasses import dataclass from enum import Enum +import torch +from accelerate.utils import FullyShardedDataParallelPlugin +from accelerate.utils.constants import FSDP_BACKWARD_PREFETCH from accelerate.utils.dataclasses import BaseEnum, TorchDynamoPlugin from accelerate.utils.environment import str_to_bool @@ -31,12 +34,14 @@ class GaudiDistributedType(str, Enum): - **NO** -- Not a distributed environment, just a single process. - **MULTI_HPU** -- Distributed on multiple HPUs. - **DEEPSPEED** -- Using DeepSpeed. + - **FSDP** -- Using FSDP. """ # Subclassing str as well as Enum allows the `GaudiDistributedType` to be JSON-serializable out of the box. NO = "NO" MULTI_HPU = "MULTI_HPU" DEEPSPEED = "DEEPSPEED" + FSDP = "FSDP" class GaudiDynamoBackend(str, BaseEnum): @@ -106,3 +111,36 @@ def __post_init__(self): self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1 if self.dynamic is None: self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1 + + +@dataclass +class GaudiFullyShardedDataParallelPlugin(FullyShardedDataParallelPlugin): + def __post_init__(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy + + prefix = "FSDP_" + if self.sharding_strategy is None: + self.sharding_strategy = ShardingStrategy(int(os.environ.get(prefix + "SHARDING_STRATEGY", 1))) + + if self.cpu_offload is None: + if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1: + self.cpu_offload = CPUOffload(offload_params=True) + else: + self.cpu_offload = CPUOffload(offload_params=False) + + if self.backward_prefetch is None: + prefetch_policy = os.environ.get(prefix + "BACKWARD_PREFETCH", "NO_PREFETCH") + if prefetch_policy != FSDP_BACKWARD_PREFETCH[-1]: + self.backward_prefetch = BackwardPrefetch(FSDP_BACKWARD_PREFETCH.index(prefetch_policy) + 1) + + if self.state_dict_type is None: + state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT") + self.set_state_dict_type(state_dict_type_policy) + self.use_orig_params = str_to_bool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1 + self.sync_module_states = str_to_bool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1 + self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1 + self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 + + if self.sync_module_states: + device = torch.device("hpu") + self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) diff --git a/optimum/habana/peft/__init__.py b/optimum/habana/peft/__init__.py new file mode 100644 index 0000000000..2ba5892ad3 --- /dev/null +++ b/optimum/habana/peft/__init__.py @@ -0,0 +1 @@ +from .layer import GaudiLoraLayerLinearForward diff --git a/optimum/habana/peft/layer.py b/optimum/habana/peft/layer.py new file mode 100644 index 0000000000..f2e7561792 --- /dev/null +++ b/optimum/habana/peft/layer.py @@ -0,0 +1,31 @@ +from typing import Any + +import torch + + +def GaudiLoraLayerLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + # https://github.com/huggingface/peft/blob/4b02148af252c17e36b0a4b995f9e8519806fbb5/src/peft/tuners/lora/layer.py#L354C1-L376C22 + # only differences are avoiding inplace update of "result" to prevent error from torch Dynamo in torch.compile mode of execution + # and replacing self.base_layer by self._linear + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self._linear(x, *args, **kwargs) + elif self.merged: + result = self._linear(x, *args, **kwargs) + else: + result = self._linear(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + result = result.clone() + lora_B(lora_A(dropout(x))) * scaling + + result = result.to(previous_dtype) + return result diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 8071985a6f..8ca4dfe58e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -60,8 +60,14 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): - override RMSNorm with Habana fused RMSNorm """ if hidden_states.device.type == "hpu" and FusedRMSNorm: - hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) - return hidden_states + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index c393733f22..c04836a815 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -33,6 +33,7 @@ from accelerate import skip_first_batches from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin from huggingface_hub import upload_folder +from packaging import version from torch.utils.data import DataLoader, Dataset, RandomSampler from transformers import Trainer from transformers.data.data_collator import DataCollator @@ -79,6 +80,7 @@ WEIGHTS_INDEX_NAME, WEIGHTS_NAME, PushInProgress, + is_accelerate_available, is_datasets_available, is_peft_available, is_safetensors_available, @@ -115,6 +117,12 @@ if is_deepspeed_available(): from accelerate.utils import DeepSpeedSchedulerWrapper +if is_accelerate_available(): + from accelerate import __version__ as accelerate_version + from accelerate.utils import ( + load_fsdp_optimizer, + save_fsdp_optimizer, + ) if TYPE_CHECKING: import optuna @@ -287,8 +295,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: else: num_samples = len(self.train_dataset) if ( - self.args.use_lazy_mode - and not self.args.dataloader_drop_last + not self.args.dataloader_drop_last and len(self.train_dataset) % self.args.per_device_train_batch_size != 0 and self.args.parallel_mode != ParallelMode.DISTRIBUTED ): @@ -467,7 +474,7 @@ def train( if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") - if resume_from_checkpoint is not None and not self.is_deepspeed_enabled: + if resume_from_checkpoint is not None and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped @@ -563,6 +570,8 @@ def _inner_training_loop( if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa + delay_optimizer_creation = self.is_fsdp_enabled + # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: self.lr_scheduler = None @@ -571,7 +580,8 @@ def _inner_training_loop( if self.is_deepspeed_enabled: self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) - self.create_optimizer_and_scheduler(num_training_steps=max_steps) + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None @@ -635,6 +645,11 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False + if delay_optimizer_creation: + if use_accelerator_prepare: + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # prepare using `accelerator` prepare if use_accelerator_prepare: self.model.train() @@ -646,6 +661,9 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): self.model, self.optimizer, self.lr_scheduler ) + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model @@ -658,6 +676,10 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): if resume_from_checkpoint is not None and self.is_deepspeed_enabled: deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + # fsdp ckpt loading + if resume_from_checkpoint is not None and self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) @@ -677,6 +699,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): # important: at this point: # self.model is the Transformers Model # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. # Train! logger.info("***** Running training *****") @@ -1019,6 +1042,10 @@ def _load_best_model(self): # TODO: check if the code below works # if self.is_deepspeed_enabled: # deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) + # elif self.is_fsdp_enabled: + # load_result = load_fsdp_model( + # self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint + # ) if ( os.path.exists(best_model_path) or os.path.exists(best_safe_model_path) @@ -1179,8 +1206,13 @@ def _save_checkpoint(self, model, trial, metrics=None): else: self.model_wrapped.save_checkpoint(output_dir) + if self.is_fsdp_enabled: + save_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir + ) + # Save optimizer and scheduler - if self.args.should_save and not self.is_deepspeed_enabled: + if self.args.should_save and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: # deepspeed.save_checkpoint above saves model/optim/sched # This block is exectuted by the main process only optim_dict = self.optimizer.state_dict() @@ -1285,10 +1317,18 @@ def _load_optimizer_and_scheduler(self, checkpoint): # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more # likely to get OOM on CPU (since we load num_gpu times the optimizer state map_location = "cpu" if self.args.use_habana else self.args.device - - self.optimizer.load_state_dict( - torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) - ) + if self.is_fsdp_enabled: + load_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, + self.accelerator, + self.optimizer, + self.model, + checkpoint, + ) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) with warnings.catch_warnings(record=True) as caught_warnings: self.lr_scheduler.load_state_dict( @@ -1395,8 +1435,17 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa """ if output_dir is None: output_dir = self.args.output_dir - - if self.is_deepspeed_enabled: + # copy from https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/trainer.py#L2825 + # Note we picked this code from transformers 0.36.2 (when rest of code is from older version) because without this checkpoint with LoRA + # was not coming out correct. + if self.is_fsdp_enabled: + if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and ( + version.parse(accelerate_version) > version.parse("0.24.1") + ): + state_dict = self.accelerator.get_state_dict(self.model) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + elif self.is_deepspeed_enabled: # this takes care of everything as long as we aren't under zero3 try: state_dict = self.accelerator.get_state_dict(self.deepspeed) @@ -1494,6 +1543,9 @@ def evaluation_loop( else self.accelerator.prepare_model(model, evaluation_mode=True) ) + if self.is_fsdp_enabled: + self.model = model + # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model @@ -1501,6 +1553,7 @@ def evaluation_loop( # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped + model.eval() # Do not use HPU graphs if the training is ongoing because it detaches gradients @@ -2037,13 +2090,13 @@ def create_accelerator_and_postprocess(self): dispatch_batches=self.args.dispatch_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, - even_batches=self.args.use_lazy_mode and not self.args.dataloader_drop_last, + even_batches=not self.args.dataloader_drop_last, distribution_strategy=self.args.distribution_strategy, ) # deepspeed and accelerate flags covering both trainer args and accelerate launcher self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None - self.is_fsdp_enabled = False + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None # post accelerator creation setup if self.is_deepspeed_enabled: @@ -2055,6 +2108,24 @@ def create_accelerator_and_postprocess(self): ds_plugin.hf_ds_config = GaudiTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config + # copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/trainer.py#L3991 + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( + "limit_all_gathers", fsdp_plugin.limit_all_gathers + ) + if is_accelerate_available("0.23.0"): + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( + "activation_checkpointing", fsdp_plugin.activation_checkpointing + ) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." + ) + def _zero_model_grad(self, model): if hasattr(model, "_zero_grad_kwargs"): model.zero_grad(**model._zero_grad_kwargs) diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index b09e3be4a4..3d6bedc4c7 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io +import json import os import warnings from dataclasses import asdict, dataclass, field @@ -23,7 +25,7 @@ from packaging import version from transformers.debug_utils import DebugOption from transformers.file_utils import cached_property, is_torch_available, requires_backends -from transformers.trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType +from transformers.trainer_utils import EvaluationStrategy, FSDPOption, HubStrategy, IntervalStrategy, SchedulerType from transformers.training_args import ( OptimizerNames, ParallelMode, @@ -58,7 +60,6 @@ "fp16_backend", "fp16_full_eval", "fp16_opt_level", - "fsdp", "mp_parameters", "tf32", "tpu_metrics_debug", @@ -318,8 +319,6 @@ def __post_init__(self): "--fp16, --fp16_backend, --fp16_full_eval and --fp16_opt_level are not" " supported by optimum-habana. Mixed-precision can be enabled in your Gaudi configuration." ) - if self.fsdp: - raise ValueError("--fsdp is not supported by optimum-habana.") if self.tpu_num_cores or self.tpu_metrics_debug: raise ValueError("TPUs are not supported by optimum-habana.") if self.mp_parameters: @@ -506,6 +505,100 @@ def __post_init__(self): " during training" ) + # Copy of https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/training_args.py#L1563 + # except following changes, (1) Remove XLA specific code & (2) change fsdp_backward_prefetch to backward_prefetch + if isinstance(self.fsdp, bool): + self.fsdp = "full_shard" if self.fsdp else "" + if isinstance(self.fsdp, str): + self.fsdp = [FSDPOption(s) for s in self.fsdp.split()] + if self.fsdp == [FSDPOption.OFFLOAD]: + raise ValueError( + "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or " + '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.' + ) + elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp: + raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") + + if self.fsdp_config is None: + self.fsdp_config = {} + + if isinstance(self.fsdp_config, str): + if len(self.fsdp) == 0: + warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.") + with io.open(self.fsdp_config, "r", encoding="utf-8") as f: + self.fsdp_config = json.load(f) + for k in list(self.fsdp_config.keys()): + if k.startswith("fsdp_"): + v = self.fsdp_config.pop(k) + self.fsdp_config[k[5:]] = v + + if self.fsdp_min_num_params > 0: + warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning) + + self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params) + + # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object + if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str): + self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]] + + if self.fsdp_transformer_layer_cls_to_wrap is not None: + warnings.warn( + "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning + ) + self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get( + "transformer_layer_cls_to_wrap", [] + ) + [self.fsdp_transformer_layer_cls_to_wrap] + + if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0: + warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.") + + if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: + warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") + + if ( + len(self.fsdp) > 0 + and self.fsdp_config["min_num_params"] > 0 + and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None + ): + raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.") + self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) + self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) + + # accelerate integration for FSDP + if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: + os.environ["ACCELERATE_USE_FSDP"] = "true" + from accelerate.utils.constants import ( + FSDP_AUTO_WRAP_POLICY, + FSDP_SHARDING_STRATEGY, + ) + + prefix = "FSDP_" + for fsdp_option in self.fsdp: + if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: + # set environment variable for FSDP sharding strategy + os.environ[f"{prefix}SHARDING_STRATEGY"] = str( + FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1 + ) + elif fsdp_option == FSDPOption.OFFLOAD: + os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true" + elif fsdp_option == FSDPOption.AUTO_WRAP: + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] + if self.fsdp_config["min_num_params"] > 0: + os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"]) + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] + elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: + os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join( + self.fsdp_config["transformer_layer_cls_to_wrap"] + ) + prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") + os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() + os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefect", "false")) + os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(self.fsdp_config.get("sync_module_states", "true")) + os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "false")) + os.environ[f"{prefix}ACTIVATION_CHECKPOINTING"] = str( + self.fsdp_config.get("activation_checkpointing", "false") + ) + if isinstance(self.debug, str): self.debug = [DebugOption(s) for s in self.debug.split()] elif self.debug is None: diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py new file mode 100644 index 0000000000..af82965063 --- /dev/null +++ b/tests/test_fsdp_examples.py @@ -0,0 +1,127 @@ +import json +import os +import re +import subprocess +from pathlib import Path +from tempfile import TemporaryDirectory + +import pytest + +from .test_examples import ACCURACY_PERF_FACTOR, TIME_PERF_FACTOR + + +# Gaudi2 CI baselines +MODELS_TO_TEST = { + "bf16": [ + ( + "bert-base-uncased", + "Habana/bert-base-uncased", + 2807, + 85.4688, + "question-answering", + 24, + 8, + "run_qa.py", + "full_shard", + ), + ], +} + + +def _test_fsdp( + model_name: str, + gaudi_config: str, + baseline: float, + baseline_acc: float, + task: str, + batch_size_train: int, + batch_size_eval: int, + script: str, + policy: str, + world_size: int = 8, +): + os.environ["PT_HPU_LAZY_MODE"] = "0" + os.environ["PT_HPU_EAGER_4_STAGE_PIPELINE_ENABLE"] = "0" # To be removed later + os.environ["PT_HPU_EAGER_PIPELINE_ENABLE"] = "0" # To be removed later + path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" + + # Install question-answering example requirements + cmd_line = f"pip install -r {path_to_example_dir / task / 'requirements.txt'}".split() + p = subprocess.Popen(cmd_line) + return_code = p.wait() + assert return_code == 0 + + command = ["python3"] + + command += [ + f"{path_to_example_dir / 'gaudi_spawn.py'}", + "--use_mpi", + f"--world_size {world_size}", + ] + + command += [ + f"{path_to_example_dir / task / script}", + f"--model_name_or_path {model_name}", + "--do_train", + "--dataset_name squad", + "--max_seq_length 384", + f"--per_device_eval_batch_size {batch_size_eval}", + f"--per_device_train_batch_size {batch_size_train}", + "--learning_rate 3e-05", + "--num_train_epochs 2.0", + "--logging_steps 20", + "--save_steps 5000", + "--seed 42", + "--doc_stride 128", + "--use_habana", + "--overwrite_output_dir", + f"--gaudi_config_name {gaudi_config}", + "--throughput_warmup_steps 100", + f"--fsdp_config {path_to_example_dir / task / 'fsdp_config.json'}", + f"--fsdp '{policy}'", + "--do_eval", + "--torch_compile_backend aot_hpu_training_backend", + "--torch_compile", + ] + + with TemporaryDirectory() as tmp_dir: + command.append(f"--output_dir {tmp_dir}") + print(f"\n\nCommand to test: {' '.join(command)}\n") + + pattern = re.compile(r"([\"\'].+?[\"\'])|\s") + command = [x for y in command for x in re.split(pattern, y) if x] + + proc = subprocess.run(command) + + # Ensure the run finished without any issue + # Use try-except to avoid logging the token if used + try: + assert proc.returncode == 0 + except AssertionError as e: + if "'--token', 'hf_" in e.args[0]: + e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) + raise + + with open(Path(tmp_dir) / "all_results.json") as fp: + results = json.load(fp) + + # Ensure performance requirements (throughput) are met + assert results["train_samples_per_second"] >= (2 - TIME_PERF_FACTOR) * baseline + assert results["eval_f1"] >= ACCURACY_PERF_FACTOR * baseline_acc + + +@pytest.mark.parametrize( + "model_name, gaudi_config, baseline, baseline_acc, task, bs_train, bs_eval, script, policy", MODELS_TO_TEST["bf16"] +) +def test_fsdp_bf16( + model_name: str, + gaudi_config: str, + baseline: float, + baseline_acc: float, + task: str, + bs_train: int, + bs_eval: int, + script: str, + policy: str, +): + _test_fsdp(model_name, gaudi_config, baseline, baseline_acc, task, bs_train, bs_eval, script, policy)