diff --git a/docs/source/en/main_classes/trainer.mdx b/docs/source/en/main_classes/trainer.mdx index a0b914cd40af..67ab6aba42ef 100644 --- a/docs/source/en/main_classes/trainer.mdx +++ b/docs/source/en/main_classes/trainer.mdx @@ -564,32 +564,69 @@ as the model saving with FSDP activated is only available with recent fixes. - **Sharding Strategy**: - FULL_SHARD : Shards optimizer states + gradients + model parameters across data parallel workers/GPUs. - For this, add `--fsdp full_shard` to the command line arguments. + For this, add `--fsdp full_shard` to the command line arguments. - SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs. For this, add `--fsdp shard_grad_op` to the command line arguments. - NO_SHARD : No sharding. For this, add `--fsdp no_shard` to the command line arguments. - To offload the parameters and gradients to the CPU, -add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments. -- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`, -add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments. + add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments. +- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`, + add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments. - To enable both CPU offloading and auto wrapping, -add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments. -- If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy. - - For transformer based auto wrap policy, please add `--fsdp_transformer_layer_cls_to_wrap ` to command line arguments. - This specifies the transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... - This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units. - Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers. - Remaining layers including the shared embeddings are conveniently wrapped in same outermost FSDP unit. - Therefore, use this for transformer based models. - - For size based auto wrap policy, please add `--fsdp_min_num_params ` to command line arguments. - It specifies FSDP's minimum number of parameters for auto wrapping. + add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments. +- Remaining FSDP config is passed via `--fsdp_config `. It is either a location of + FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`. + - If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy. + - For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file. + This specifies the list of transformer layer class name (case-sensitive) to wrap ,e.g, [`BertLayer`], [`GPTJBlock`], [`T5Block`] .... + This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units. + Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers. + Remaining layers including the shared embeddings are conveniently wrapped in same outermost FSDP unit. + Therefore, use this for transformer based models. + - For size based auto wrap policy, please add `fsdp_min_num_params` in the config file. + It specifies FSDP's minimum number of parameters for auto wrapping. + - `fsdp_backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters. + `backward_pre` and `backward_pos` are available options. + For more information refer `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch` + - `fsdp_forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters. + If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. + - `limit_all_gathers` can be specified in the config file. + If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. **Few caveats to be aware of** -- Mixed precision is currently not supported with FSDP as we wait for PyTorch to fix support for it. -More details in this [issues](https://github.com/pytorch/pytorch/issues/75676). -- FSDP currently doesn't support multiple parameter groups. -More details mentioned in this [issue](https://github.com/pytorch/pytorch/issues/76501) -(`The original model parameters' .grads are not set, meaning that they cannot be optimized separately (which is why we cannot support multiple parameter groups)`). +- it is incompatible with `generate`, thus is incompatible with `--predict_with_generate` + in all seq2seq/clm scripts (translation/summarization/clm etc.). + Please refer issue [#21667](https://github.com/huggingface/transformers/issues/21667) + +### PyTorch/XLA Fully Sharded Data parallel + +For all the TPU users, great news! PyTorch/XLA now supports FSDP. +All the latest Fully Sharded Data Parallel (FSDP) training are supported. +For more information refer to the [Scaling PyTorch models on Cloud TPUs with FSDP](https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/) and [PyTorch/XLA implementation of FSDP](https://github.com/pytorch/xla/tree/master/torch_xla/distributed/fsdp) +All you need to do is enable it through the config. + +**Required PyTorch/XLA version for FSDP support**: >=2.0 + +**Usage**: + +Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_config `: +- `xla` should be set to `True` to enable PyTorch/XLA FSDP. +- `xla_fsdp_settings` The value is a dictionary which stores the XLA FSDP wrapping parameters. + For a complete list of options, please see [here]( + https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py). +- `xla_fsdp_grad_ckpt`. When `True`, uses gradient checkpointing over each nested XLA FSDP wrapped layer. + This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through + `fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`. +- You can either use transformer based auto wrap policy or size based auto wrap policy. + - For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file. + This specifies the list of transformer layer class name (case-sensitive) to wrap ,e.g, [`BertLayer`], [`GPTJBlock`], [`T5Block`] .... + This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units. + Using this policy, wrapping happens for each block containing Multi-Head Attention followed by couple of MLP layers. + Remaining layers including the shared embeddings are conveniently wrapped in same outermost FSDP unit. + Therefore, use this for transformer based models. + - For size based auto wrap policy, please add `fsdp_min_num_params` in the config file. + It specifies FSDP's minimum number of parameters for auto wrapping. + ### Using Trainer for accelerated PyTorch Training on Mac