From 58b8f580b897795aba485f03e0f5ff4b452acf10 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 5 May 2022 18:35:52 +0530
Subject: [PATCH 1/4] PyTorch FSDP integration in Trainer

---
 docs/source/en/main_classes/trainer.mdx | 36 +++++++++++++
 src/transformers/trainer.py             | 68 +++++++++++++++++++++++--
 src/transformers/trainer_utils.py       |  7 +++
 src/transformers/training_args.py       | 52 ++++++++++++++++++-
 4 files changed, 157 insertions(+), 6 deletions(-)

diff --git a/docs/source/en/main_classes/trainer.mdx b/docs/source/en/main_classes/trainer.mdx
index a3c57b51f570..42677f7eb908 100644
--- a/docs/source/en/main_classes/trainer.mdx
+++ b/docs/source/en/main_classes/trainer.mdx
@@ -540,6 +540,42 @@ Known caveats:
   `FullyShardedDataParallelism` of fairscale. It should be used with the option `auto_wrap` if you are not
   doing this yourself: `--sharded_ddp "zero_dp_3 auto_wrap"`.
 
+### PyTorch Fully Sharded Data Parallel
+
+To accelerate the training of huge models on larger batch sizes, we can use a fully sharded data parallel model.
+This type of data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters.
+To read more about it and its benefits, check out the [Fully Sharded Data Parallel blog](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
+We have integrated PyTorch's latest Fully Sharded Data Parallel (FSDP) training feature.
+All you need to do is enable it through the config.
+
+**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released),
+as model saving with FSDP activated is only available with recent fixes.
+
+**Usage**:
+
+- Make sure you have added the distributed launcher
+`-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you haven't been using it already.
+
+- **Sharding Strategy**:
+  - FULL_SHARD: Shards optimizer states + gradients + model parameters across data parallel workers/GPUs.
+  For this, add `--fsdp full_shard` to the command line arguments.
+  - SHARD_GRAD_OP: Shards optimizer states + gradients across data parallel workers/GPUs.
+  For this, add `--fsdp shard_grad_op` to the command line arguments.
+- To offload the parameters and gradients to the CPU,
+add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
+- To automatically and recursively wrap layers with FSDP using `default_auto_wrap_policy`,
+add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments.
+- To enable both CPU offloading and auto wrapping,
+add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments.
+- If auto wrapping is enabled, add `--fsdp_min_num_params <number>` to the command line arguments.
+It specifies FSDP's minimum number of parameters for Default Auto Wrapping.
+
+**A few caveats to be aware of**
+- Mixed precision is currently not supported with FSDP, as we are waiting for PyTorch to fix its support for it.
+More details in this [issue](https://github.com/pytorch/pytorch/issues/75676).
+- FSDP currently doesn't support multiple parameter groups.
+More details in this [issue](https://github.com/pytorch/pytorch/issues/76501)
+(`The original model parameters' .grads are not set, meaning that they cannot be optimized separately (which is why we cannot support multiple parameter groups)`).
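For reference, the flags above map onto `TrainingArguments` as in the following minimal sketch, configured from Python rather than from the command line. The model, dataset and output directory are placeholders, and the script is assumed to be launched with `python -m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE train.py`:

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",  # placeholder output directory
    do_train=True,
    # Equivalent to `--fsdp "full_shard auto_wrap"`: shard parameters, gradients and
    # optimizer states, and recursively wrap submodules with FSDP.
    fsdp="full_shard auto_wrap",
    # Equivalent to `--fsdp_min_num_params 20000000`: FSDP's minimum number of
    # parameters for Default Auto Wrapping (the threshold chosen here is arbitrary).
    fsdp_min_num_params=20_000_000,
    # Do not set fp16/bf16 here: mixed precision is not yet supported with FSDP (see the caveats above).
)

trainer = Trainer(
    model=model,                  # placeholder: any PreTrainedModel defined elsewhere
    args=training_args,
    train_dataset=train_dataset,  # placeholder: any training dataset defined elsewhere
)
trainer.train()
```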
Sections that were moved: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6e339e8652cd..cd5760adfae1 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -17,6 +17,7 @@ """ import contextlib +import functools import inspect import math import os @@ -103,6 +104,7 @@ BestRun, EvalLoopOutput, EvalPrediction, + FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, @@ -340,6 +342,10 @@ def __init__( raise ValueError( "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." ) + if len(args.fsdp) > 0: + raise ValueError( + "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." + ) if args.local_rank == -1: raise ValueError("Using sharded DDP only works in distributed training.") @@ -357,6 +363,22 @@ def __init__( elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 + self.fsdp = None + if len(args.fsdp) > 0: + if args.deepspeed: + raise ValueError( + "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if args.local_rank == -1: + raise ValueError("Using fsdp only works in distributed training.") + dep_version_check("torch>=1.12.0.dev20220418+cu113") + from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy + + if FSDPOption.FULL_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.FULL_SHARD + elif FSDPOption.SHARD_GRAD_OP in args.fsdp: + self.fsdp = ShardingStrategy.SHARD_GRAD_OP + # one place to sort out whether to place the model on device or not # postpone switching model to cuda when: # 1. MP - since we are trying to fit a much bigger than 1 gpu model @@ -364,12 +386,14 @@ def __init__( # and we only use deepspeed for training at the moment # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first # 4. Sharded DDP - same as MP + # 5. FSDP - same as MP self.place_model_on_device = args.place_model_on_device if ( self.is_model_parallel or args.deepspeed or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) + or (self.fsdp is not None) ): self.place_model_on_device = False @@ -398,11 +422,11 @@ def __init__( "Passing a `model_init` is incompatible with providing the `optimizers` argument. " "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) - if (self.sharded_ddp is not None or args.deepspeed) and ( + if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and ( self.optimizer is not None or self.lr_scheduler is not None ): raise RuntimeError( - "Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled." + "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) @@ -450,6 +474,11 @@ def __init__( self.use_amp = False if args.fp16 or args.bf16: + if self.fsdp is not None: + raise ValueError( + "Mixed precision is currently not supported for FSDP." 
+ "Please do not set arguments related to `mixed_precision`" + ) if args.half_precision_backend == "auto": if _is_native_amp_available: args.half_precision_backend = "amp" @@ -1102,6 +1131,32 @@ def _wrap_model(self, model, training=True): cpu_offload=cpu_offload, ).to(self.args.device) + # Distributed training using PyTorch FSDP + if self.fsdp is not None: + # PyTorch FSDP! + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.wrap import default_auto_wrap_policy + + if FSDPOption.OFFLOAD in self.args.fsdp: + cpu_offload = CPUOffload(offload_params=True) + else: + cpu_offload = CPUOffload(offload_params=False) + + auto_wrap_policy = None + if FSDPOption.AUTO_WRAP in self.args.fsdp: + if self.args.fsdp_min_num_params > 0: + auto_wrap_policy = functools.partial( + default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params + ) + + if type(model) != FSDP: + # XXX: Breaking the self.model convention but I see no way around it for now. + self.model = model = FSDP( + model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy + ) + if not FSDPOption.OFFLOAD in self.args.fsdp: + model.to(self.args.device) + elif is_sagemaker_dp_enabled(): model = nn.parallel.DistributedDataParallel( model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] @@ -1253,7 +1308,10 @@ def train( debug_overflow = DebugUnderflowOverflow(self.model) # noqa delay_optimizer_creation = ( - self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled() + self.sharded_ddp is not None + and self.sharded_ddp != ShardedDDPOption.SIMPLE + or is_sagemaker_mp_enabled() + or self.fsdp is not None ) if args.deepspeed: deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( @@ -2138,7 +2196,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa if self.args.should_save: self._save(output_dir, state_dict=state_dict) elif ( - ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp + or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + or self.fsdp is not None ): state_dict = self.model.state_dict() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 4450bfde646e..5369b2e78023 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -582,3 +582,10 @@ class ShardedDDPOption(ExplicitEnum): ZERO_DP_3 = "zero_dp_3" OFFLOAD = "offload" AUTO_WRAP = "auto_wrap" + + +class FSDPOption(ExplicitEnum): + FULL_SHARD = "full_shard" + SHARD_GRAD_OP = "shard_grad_op" + OFFLOAD = "offload" + AUTO_WRAP = "auto_wrap" diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index cc0a5ec83570..f1bf72ec9b2a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -23,7 +23,14 @@ from typing import Any, Dict, List, Optional from .debug_utils import DebugOption -from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption +from .trainer_utils import ( + EvaluationStrategy, + FSDPOption, + HubStrategy, + IntervalStrategy, + SchedulerType, + ShardedDDPOption, +) from .utils import ( ExplicitEnum, cached_property, @@ -331,6 +338,17 @@ class TrainingArguments: If a string is passed, it will be split on space. 
             If a bool is passed, it will be converted to an empty list for `False` and `["simple"]` for `True`.
+        fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`):
+            Use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only).
+
+            A list of options from the following:
+
+            - `"full_shard"`: Shard parameters, gradients and optimizer states.
+            - `"shard_grad_op"`: Shard optimizer states and gradients.
+            - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and `"shard_grad_op"`).
+            - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
+        fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+            FSDP's minimum number of parameters for Default Auto Wrapping (useful only when the `fsdp` field is passed).
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -674,10 +692,25 @@ class TrainingArguments:
         metadata={
             "help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
             "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
-            "like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
+            "like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3` "
             "with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
         },
     )
+    fsdp: str = field(
+        default="",
+        metadata={
+            "help": "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only). The base option "
+            "should be `full_shard` or `shard_grad_op` and you can add CPU-offload to `full_shard` or `shard_grad_op` "
+            "like this: `full_shard offload` or `shard_grad_op offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` "
+            "with the same syntax: `full_shard auto_wrap` or `shard_grad_op auto_wrap`.",
+        },
+    )
+    fsdp_min_num_params: int = field(
+        default=0,
+        metadata={
+            "help": "FSDP's minimum number of parameters for Default Auto Wrapping (useful only when the `fsdp` field is passed)."
+        },
+    )
     deepspeed: Optional[str] = field(
         default=None,
         metadata={
@@ -931,6 +964,21 @@ def __post_init__(self):
         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
 
+        if isinstance(self.fsdp, bool):
+            self.fsdp = "full_shard" if self.fsdp else ""
+        if isinstance(self.fsdp, str):
+            self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
+        if self.fsdp == [FSDPOption.OFFLOAD]:
+            raise ValueError(
+                "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
+                '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
+            )
+        elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
+            raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+
+        if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
+            warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+
         if self.tpu_metrics_debug:
             warnings.warn(
                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers.
Use `--debug tpu_metrics_debug` instead", From 9e33d57442cec90ac075ceda4835d80ff1712675 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 5 May 2022 19:30:00 +0530 Subject: [PATCH 2/4] reformatting make style and make quality are now compliant. --- src/transformers/trainer.py | 5 +++-- src/transformers/training_args.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cd5760adfae1..cb904a7a3232 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1134,7 +1134,8 @@ def _wrap_model(self, model, training=True): # Distributed training using PyTorch FSDP if self.fsdp is not None: # PyTorch FSDP! - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import default_auto_wrap_policy if FSDPOption.OFFLOAD in self.args.fsdp: @@ -1154,7 +1155,7 @@ def _wrap_model(self, model, training=True): self.model = model = FSDP( model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy ) - if not FSDPOption.OFFLOAD in self.args.fsdp: + if FSDPOption.OFFLOAD not in self.args.fsdp: model.to(self.args.device) elif is_sagemaker_dp_enabled(): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index f1bf72ec9b2a..b1c3f8b2558b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -345,7 +345,8 @@ class TrainingArguments: - `"full_shard"`: Shard parameters, gradients and optimizer states. - `"shard_grad_op"`: Shard optimizer states and gradients. - - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and `"shard_grad_op"`). + - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and + `"shard_grad_op"`). - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`. fsdp_min_num_params (`int`, *optional*, defaults to `0`): FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed). From d6564a68ebc9f1f36ee6e324d8c9e54a1b48884c Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Fri, 6 May 2022 18:08:02 +0530 Subject: [PATCH 3/4] Updating dependency check --- src/transformers/trainer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cb904a7a3232..e3e6c0297bef 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -371,7 +371,15 @@ def __init__( ) if args.local_rank == -1: raise ValueError("Using fsdp only works in distributed training.") - dep_version_check("torch>=1.12.0.dev20220418+cu113") + + # dep_version_check("torch>=1.12.0.dev20220418+cu113") + # Would have to update setup.py with torch>=1.12.0.dev20220418+cu113 + # which isn't ideally given that it's a dev version + # and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113 + # below is the current alternative. 
+ if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"): + raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113") + from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy if FSDPOption.FULL_SHARD in args.fsdp: From 3b73fae1c98398a33accdd7367a5d1fcfe083147 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Mon, 9 May 2022 10:51:56 -0400 Subject: [PATCH 4/4] Trigger CI
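Taken together, the core of this series is the option parsing in `TrainingArguments`/`Trainer.__init__` and the wrapping step added to `Trainer._wrap_model`. The following is a minimal, self-contained sketch of that wrapping logic as a standalone helper, assuming the nightly PyTorch build required above and an already initialized process group (e.g. via `torch.distributed.launch`); the helper name and its arguments are illustrative and not part of the patch.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import default_auto_wrap_policy


def wrap_with_fsdp(
    model: nn.Module,
    device: torch.device,
    full_shard: bool = True,
    offload: bool = False,
    min_num_params: int = 0,
) -> nn.Module:
    # `--fsdp full_shard` vs. `--fsdp shard_grad_op`
    sharding_strategy = ShardingStrategy.FULL_SHARD if full_shard else ShardingStrategy.SHARD_GRAD_OP
    # `--fsdp "... offload"`: keep parameters and gradients on the CPU
    cpu_offload = CPUOffload(offload_params=offload)

    # `--fsdp "... auto_wrap"` together with `--fsdp_min_num_params`: recursively wrap
    # submodules whose parameter count exceeds the threshold
    auto_wrap_policy = None
    if min_num_params > 0:
        auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=min_num_params)

    model = FSDP(
        model,
        sharding_strategy=sharding_strategy,
        cpu_offload=cpu_offload,
        auto_wrap_policy=auto_wrap_policy,
    )
    if not offload:
        # As in `_wrap_model`, only move to the accelerator when parameters are not CPU-offloaded.
        model.to(device)
    return model
```

Calling `wrap_with_fsdp(model, device, full_shard=True, offload=False, min_num_params=20_000_000)` corresponds to running the `Trainer` with `--fsdp "full_shard auto_wrap" --fsdp_min_num_params 20000000`.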