36 changes: 36 additions & 0 deletions docs/source/en/main_classes/trainer.mdx
@@ -540,6 +540,42 @@ Known caveats:
`FullyShardedDataParallelism` of fairscale. It should be used with the option `auto_wrap` if you are not
doing this yourself: `--sharded_ddp "zero_dp_3 auto_wrap"`.

### PyTorch Fully Sharded Data Parallel

To accelerate training of huge models on larger batch sizes, we can use a fully sharded data parallel model.
This data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters across data parallel workers.
To read more about it and its benefits, check out the [Fully Sharded Data Parallel blog](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
We have integrated the latest PyTorch Fully Sharded Data Parallel (FSDP) training feature.
All you need to do is enable it through the config.

**Required PyTorch version for FSDP support**: PyTorch Nightly (or 1.12.0 if you read this after it has been released),
as model saving with FSDP activated is only available with recent fixes.

**Usage**:

- Make sure you have added the distributed launcher
`-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you are not already using it.

- **Sharding Strategy**:
- FULL_SHARD : Shards optimizer states + gradients + model parameters across data parallel workers/GPUs.
For this, add `--fsdp full_shard` to the command line arguments.
- SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs.
For this, add `--fsdp shard_grad_op` to the command line arguments.
- To offload the parameters and gradients to the CPU,
add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`,
add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments.
- To enable both CPU offloading and auto wrapping,
add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments.
- If auto wrapping is enabled, please add `--fsdp_min_num_params <number>` to the command line arguments.
It specifies FSDP's minimum number of parameters for default auto wrapping (see the sketch after this list).
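
The same options can also be set programmatically through `TrainingArguments` instead of command line flags. Below is a minimal sketch; the `output_dir`, script name and the parameter threshold are illustrative placeholders, not values prescribed by this integration.

```python
# Minimal sketch: the training script is still launched with the distributed launcher, e.g.
#   python -m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE my_training_script.py
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",             # placeholder output directory
    fsdp="full_shard auto_wrap",     # equivalent to `--fsdp "full_shard auto_wrap"`
    fsdp_min_num_params=20_000_000,  # equivalent to `--fsdp_min_num_params 20000000` (illustrative value)
)

# `fsdp` is parsed into a list of FSDPOption values; `fsdp_min_num_params` is used to build
# the auto-wrap policy when the Trainer wraps the model with FSDP.
print(args.fsdp, args.fsdp_min_num_params)
```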

**A few caveats to be aware of**
- Mixed precision is currently not supported with FSDP as we wait for PyTorch to fix support for it.
More details in this [issue](https://github.com/pytorch/pytorch/issues/75676).
- FSDP currently doesn't support multiple parameter groups.
More details in this [issue](https://github.com/pytorch/pytorch/issues/76501)
(`The original model parameters' .grads are not set, meaning that they cannot be optimized separately (which is why we cannot support multiple parameter groups)`).

Sections that were moved:

77 changes: 73 additions & 4 deletions src/transformers/trainer.py
@@ -17,6 +17,7 @@
"""

import contextlib
import functools
import inspect
import math
import os
@@ -103,6 +104,7 @@
BestRun,
EvalLoopOutput,
EvalPrediction,
FSDPOption,
HPSearchBackend,
HubStrategy,
IntervalStrategy,
@@ -340,6 +342,10 @@ def __init__(
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if len(args.fsdp) > 0:
raise ValueError(
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
)

if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
@@ -357,19 +363,45 @@
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

self.fsdp = None
if len(args.fsdp) > 0:
if args.deepspeed:
raise ValueError(
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if args.local_rank == -1:
raise ValueError("Using fsdp only works in distributed training.")

# dep_version_check("torch>=1.12.0.dev20220418+cu113")
# Would have to update setup.py with torch>=1.12.0.dev20220418+cu113
# which isn't ideal given that it's a dev version
# and it will force people not using FSDP to also use torch>=1.12.0.dev20220418+cu113
# below is the current alternative.
if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")

from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy

if FSDPOption.FULL_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.FULL_SHARD
elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
self.fsdp = ShardingStrategy.SHARD_GRAD_OP

# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
# 4. Sharded DDP - same as MP
# 5. FSDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or args.deepspeed
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
or (self.fsdp is not None)
):
self.place_model_on_device = False

@@ -398,11 +430,11 @@ def __init__(
"Passing a `model_init` is incompatible with providing the `optimizers` argument. "
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
if (self.sharded_ddp is not None or args.deepspeed) and (
if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
"Passing `optimizers` is not allowed if Fairscale or Deepspeed is enabled."
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
@@ -450,6 +482,11 @@ def __init__(
self.use_amp = False

if args.fp16 or args.bf16:
if self.fsdp is not None:
raise ValueError(
"Mixed precision is currently not supported for FSDP."
"Please do not set arguments related to `mixed_precision`"
)
if args.half_precision_backend == "auto":
if _is_native_amp_available:
args.half_precision_backend = "amp"
@@ -1102,6 +1139,33 @@ def _wrap_model(self, model, training=True):
cpu_offload=cpu_offload,
).to(self.args.device)

# Distributed training using PyTorch FSDP
if self.fsdp is not None:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)

auto_wrap_policy = None
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_min_num_params > 0:
auto_wrap_policy = functools.partial(
default_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
)
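# Note: if `fsdp_min_num_params` is left at its default of 0, `auto_wrap_policy` stays None
# and FSDP wraps the model as a single unit instead of recursively wrapping submodules.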

if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FSDP(
model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
)
if FSDPOption.OFFLOAD not in self.args.fsdp:
model.to(self.args.device)

elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
@@ -1253,7 +1317,10 @@ def train(
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = (
self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled()
self.sharded_ddp is not None
and self.sharded_ddp != ShardedDDPOption.SIMPLE
or is_sagemaker_mp_enabled()
or self.fsdp is not None
)
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
@@ -2138,7 +2205,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
or self.fsdp is not None
):
state_dict = self.model.state_dict()

7 changes: 7 additions & 0 deletions src/transformers/trainer_utils.py
@@ -582,3 +582,10 @@ class ShardedDDPOption(ExplicitEnum):
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"


class FSDPOption(ExplicitEnum):
FULL_SHARD = "full_shard"
SHARD_GRAD_OP = "shard_grad_op"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"
53 changes: 51 additions & 2 deletions src/transformers/training_args.py
@@ -23,7 +23,14 @@
from typing import Any, Dict, List, Optional

from .debug_utils import DebugOption
from .trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption
from .trainer_utils import (
EvaluationStrategy,
FSDPOption,
HubStrategy,
IntervalStrategy,
SchedulerType,
ShardedDDPOption,
)
from .utils import (
ExplicitEnum,
cached_property,
@@ -331,6 +338,18 @@ class TrainingArguments:

If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for `False` and `["simple"]` for `True`.
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`):
Use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only).

A list of options among the following:

- `"full_shard"`: Shard parameters, gradients and optimizer states.
- `"shard_grad_op"`: Shard optimizer states and gradients.
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
`"shard_grad_op"`).
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
fsdp_min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for default auto wrapping (useful only when the `fsdp` field is passed).
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -674,10 +693,25 @@ class TrainingArguments:
metadata={
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
"like this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3` "
"with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
},
)
fsdp: str = field(
default="",
metadata={
"help": "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only). The base option "
"should be `full_shard` or `shard_grad_op` and you can add CPU-offload to `full_shard` or `shard_grad_op` "
"like this: full_shard offload` or `shard_grad_op offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` "
"with the same syntax: full_shard auto_wrap` or `shard_grad_op auto_wrap`.",
},
)
fsdp_min_num_params: int = field(
default=0,
metadata={
"help": "FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed)."
},
)
deepspeed: Optional[str] = field(
default=None,
metadata={
@@ -931,6 +965,21 @@ def __post_init__(self):
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

if isinstance(self.fsdp, bool):
self.fsdp = "full_shard" if self.fsdp else ""
if isinstance(self.fsdp, str):
self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
if self.fsdp == [FSDPOption.OFFLOAD]:
raise ValueError(
"`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
'`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
)
elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")

if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--debug tpu_metrics_debug` instead",