
### Sequence Parallelism

DeepSpeed's ALST/Ulysses sequence parallelism (also called context parallelism) enables training with very long sequences by splitting the sequence across multiple GPUs. This is particularly useful for training large language models with extended context windows.

ALST (Arctic Long Sequence Training) shards inputs along the sequence dimension and uses attention head parallelism to compute attention over the full sequence. With this approach, you can train models with sequence lengths up to 500K tokens on a single GPU, 3.7M tokens on a single node, or 15M tokens on just four nodes with Llama-8B. The implementation described here enables one component of the full ALST system. For additional optimizations like TiledMLP and activation checkpoint offloading, refer to the [DeepSpeed ALST tutorial](https://www.deepspeed.ai/tutorials/ulysses-alst-sequence-parallelism/).

> [!TIP]
> For more detailed information about sequence parallelism, see the Accelerate [Context Parallelism](https://huggingface.co/docs/accelerate/concept_guides/context_parallelism) guide.

To enable ALST/Ulysses sequence parallelism with [`Trainer`], pass a `parallelism_config` to [`TrainingArguments`]. Sequence parallelism is configured with Accelerate's `ParallelismConfig` and requires an Accelerate version newer than 1.11.0.

```py
from accelerate.utils import ParallelismConfig, DeepSpeedContextParallelConfig
from transformers import TrainingArguments

parallelism_config = ParallelismConfig(
    cp_backend="deepspeed",
    cp_size=4,  # Number of GPUs to split sequence across
    cp_handler=DeepSpeedContextParallelConfig(
        cp_seq_length_is_variable=True,
        cp_attn_implementation="sdpa",
    ),
)

training_args = TrainingArguments(
    ...,
    deepspeed="path/to/deepspeed_config.json",
    parallelism_config=parallelism_config,
)
```

You can also configure sequence parallelism using an Accelerate config file.

```yaml
distributed_type: DEEPSPEED
deepspeed_config:
  deepspeed_config_file: path/to/ds_config.json
machine_rank: 0
num_machines: 1
num_processes: 4 # Total number of processes
parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
  parallelism_config_cp_size: 4 # Sequence parallel size
  parallelism_config_cp_backend: deepspeed
  parallelism_config_cp_seq_length_is_variable: true
  parallelism_config_cp_attn_implementation: sdpa
```

Important configuration parameters include the following.

* `cp_backend` must be set to `"deepspeed"` to use ALST/Ulysses sequence parallelism.
* `cp_size` is the degree of sequence parallelism. For example, `cp_size=4` means 4 GPUs will process a single sequence in parallel. You need at least 2 GPUs to enable sequence parallelism.
* `cp_seq_length_is_variable` determines how sequence lengths are handled. When set to `True` (recommended), the implementation adapts to varying sequence lengths between batches. When `False`, all sequences must be padded to a fixed length specified by `cp_seq_length` (see the sketch after this list).
* `cp_attn_implementation` specifies the attention implementation to use. Supported values are `"sdpa"`, `"flash_attention_2"`, or `"flash_attention_3"`. Flash Attention is recommended for best performance, especially with multiple samples in a batch, because SDPA may incorrectly attend across sample boundaries.
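
If you opt out of variable sequence lengths, `cp_seq_length` sets the fixed length instead. The snippet below is a minimal sketch of that variant; the length of 8192 is an illustrative assumption, not a requirement.

```py
from accelerate.utils import DeepSpeedContextParallelConfig

# Fixed-length variant: every batch must be padded to exactly `cp_seq_length`.
# The length of 8192 is an illustrative value.
cp_handler = DeepSpeedContextParallelConfig(
    cp_seq_length_is_variable=False,
    cp_seq_length=8192,
    cp_attn_implementation="flash_attention_2",
)
```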

> [!WARNING]
> Sequence parallelism requires your model to use one of the supported attention implementations (`sdpa`, `flash_attention_2`, or `flash_attention_3`). The `eager` attention implementation is not supported because it doesn't properly handle `position_ids`.
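
One minimal way to satisfy this, sketched below with a placeholder checkpoint, is to set `attn_implementation` when loading the model; keeping it consistent with `cp_attn_implementation` is a sensible assumption.

```py
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; use the causal LM you intend to train.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    attn_implementation="flash_attention_2",  # or "sdpa" / "flash_attention_3"
)
```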

When using sequence parallelism, make sure your sequences are properly padded. Use `pad_to_multiple_of` in your data collator so that sequence lengths are divisible by `cp_size`. For example, with `cp_size=4`, set `pad_to_multiple_of=4` (or a multiple of 4).

```py
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
pad_to_multiple_of=4, # Ensure sequences are divisible by cp_size
)
```

Sequence parallelism can be combined with data parallelism for even better scaling. The example below uses 8 GPUs with 4-way sequence parallelism and 2-way data parallelism.

```py
parallelism_config = ParallelismConfig(
    cp_backend="deepspeed",
    cp_size=4,
    dp_shard_size=2,
    cp_handler=DeepSpeedContextParallelConfig(
        cp_seq_length_is_variable=True,
        cp_attn_implementation="flash_attention_2",
    ),
)
```

[`Trainer`] automatically handles the special requirements for sequence parallelism, including:

* Adapting the data loader via DeepSpeed's [`UlyssesSPDataLoaderAdapter`](https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/sequence_parallel/ulysses_sp.py) to shard sequences across GPUs
* Generating `position_ids` when not provided
* Creating `shift_labels` for causal language modeling
* Aggregating loss across sequence parallel ranks with proper masking for `-100` labels
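
To tie the pieces together, the sketch below wires the parallelism configuration and data collator from the earlier examples into a [`Trainer`] run. The checkpoint, dataset, and tokenization are placeholders chosen for illustration; replace them with your own.

```py
from accelerate.utils import DeepSpeedContextParallelConfig, ParallelismConfig
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "meta-llama/Llama-3.1-8B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the collator needs a pad token
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")

# Placeholder dataset and deliberately simple tokenization.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.map(lambda batch: tokenizer(batch["text"]), batched=True, remove_columns=["text"])

parallelism_config = ParallelismConfig(
    cp_backend="deepspeed",
    cp_size=4,
    cp_handler=DeepSpeedContextParallelConfig(
        cp_seq_length_is_variable=True,
        cp_attn_implementation="flash_attention_2",
    ),
)

training_args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    deepspeed="path/to/deepspeed_config.json",  # your ZeRO config
    parallelism_config=parallelism_config,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False, pad_to_multiple_of=4
    ),
)
trainer.train()
```

Saved as `your_training_script.py`, this is the kind of script the `accelerate launch` command below starts on each process.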

You can launch training with sequence parallelism using the `accelerate launch` command.

```bash
accelerate launch --config_file alst_config.yaml your_training_script.py \
--output_dir output_dir \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1
```

## Training features

DeepSpeed supports many training features that can be configured in the config file. This section describes some of the most important features.