40 changes: 35 additions & 5 deletions docs/design/feature/sequence_parallel.md
@@ -55,6 +55,8 @@ The `_sp_plan` mechanism allows SP **without modifying `forward()` logic**. The
- Tensor operations happen at `nn.Module` boundaries
- Predictable sharding/gathering patterns

This is the ideal approach for integrating sequence parallelism into new models, as it is easier to maintain and stays compatible with other types of acceleration.

**How it works:**
1. Declare `_sp_plan` dict in your transformer class
2. Framework automatically applies hooks when `sequence_parallel_size > 1`
@@ -201,6 +203,36 @@ class TransformerWithRoPE(nn.Module):
}
```

**Pattern 3: Shard RoPE for Dual Stream Attention**
In some cases, different attention streams need to handle sequence parallelism differently. For example, we may want to shard the image embeddings while replicating the text embeddings so that joint attention is set up correctly.

```python
class DualStreamTransformer(nn.Module):
"""
Dual-stream model where we need to replicate the text components, but shard
the image components to correctly handle sequence parallelism.
"""
_sp_plan = {
        # rope_preparer returns a 4-tuple: the first two items correspond to the
        # text stream and the last two to the image stream, so only the latter
        # pair is sharded.
"rope_preparer": {
# Outputs 0, 1 (text) - NOT sharded (replicated)
# Outputs 2, 3 (image) - sharded
2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), # img_cos
3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), # img_sin
},
# Shard transformer block INPUT
"transformer_blocks.0": {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
},
# Gather at output
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
```

NOTE: test carefully when refactoring classes that use this style of plan, as changing the order of the return values will break sequence parallelism.
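
To make the output ordering concrete, below is a minimal, hypothetical sketch of a `rope_preparer` submodule whose 4-tuple return order lines up with the plan above. The class name, shapes, and RoPE math are illustrative assumptions rather than code from the repository.

```python
import torch
from torch import nn


class RopePreparer(nn.Module):
    """Hypothetical submodule registered as `self.rope_preparer` in DualStreamTransformer."""

    def __init__(self, head_dim: int):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def _cos_sin(self, positions: torch.Tensor):
        # positions: (seq_len,) -> cos/sin of shape (seq_len, head_dim // 2)
        freqs = torch.outer(positions.float(), self.inv_freq)
        return freqs.cos(), freqs.sin()

    def forward(self, text_positions: torch.Tensor, image_positions: torch.Tensor):
        txt_cos, txt_sin = self._cos_sin(text_positions)   # outputs 0, 1 -> replicated
        img_cos, img_sin = self._cos_sin(image_positions)  # outputs 2, 3 -> sharded on dim 0
        # This tuple order is exactly what `_sp_plan["rope_preparer"]` indexes into;
        # reordering the return values is the failure mode the NOTE above warns about.
        return txt_cos, txt_sin, img_cos, img_sin
```

With this layout, plan entries 2 and 3 pick out `img_cos` / `img_sin` to shard along the sequence dimension, while the text tensors are replicated unchanged on every rank.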

### API Reference

**SequenceParallelInput Parameters:**
Expand Down Expand Up @@ -240,7 +272,7 @@ class TransformerWithRoPE(nn.Module):

## Approach 2: Intrusive Modification (For Complex Cases)

For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls. Importantly, when taking this approach, be careful to ensure that you correctly manage the `_sp_shard_depth`; if the sequence parallel shard depth is 0, Ulysses will not be used.
For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls.


**When to use:**
@@ -253,18 +285,15 @@ from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather

def forward(self, hidden_states, ...):
if self.parallel_config.sequence_parallel_size > 1:
# <Increment the _sp_shard_depth on the fwd context>
hidden_states = sp_shard(hidden_states, dim=1)

# ... computation ...

if self.parallel_config.sequence_parallel_size > 1:
output = sp_gather(output, dim=1)
# <Decrement the _sp_shard_depth on the fwd context>

return output
```
Note that currently, `sp_shard` / `sp_gather` do *not* automatically manage the `_sp_shard_depth`; you need to be careful to manage it yourself.
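
As a rough illustration of that bookkeeping, here is a hedged sketch. Only `sp_shard` / `sp_gather` come from the snippet above; how the forward context and its `_sp_shard_depth` counter are actually exposed is an assumption, so treat `forward_context` as a placeholder for whatever object carries that counter.

```python
import torch
from torch import nn

from vllm_omni.diffusion.distributed.sp_sharding import sp_gather, sp_shard


class ManuallyShardedBlock(nn.Module):
    """Sketch only: `forward_context` stands in for however the runtime exposes `_sp_shard_depth`."""

    def __init__(self, parallel_config, forward_context, dim: int = 64):
        super().__init__()
        self.parallel_config = parallel_config
        self.forward_context = forward_context  # hypothetical handle, see note above
        self.mlp = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        sp_enabled = self.parallel_config.sequence_parallel_size > 1
        if sp_enabled:
            # Record that tensors are sequence-sharded from here on;
            # a shard depth of 0 means Ulysses will not be used.
            self.forward_context._sp_shard_depth += 1
            hidden_states = sp_shard(hidden_states, dim=1)

        output = self.mlp(hidden_states)  # ... computation ...

        if sp_enabled:
            output = sp_gather(output, dim=1)
            self.forward_context._sp_shard_depth -= 1  # restore on the way out

        return output
```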

---

@@ -439,6 +468,7 @@ Complete examples in the codebase:

| Model | Path | Pattern | Notes |
|-------|------|---------|-------|
| **LongCat** | `vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py` | Dual-stream | Text components replicated, image components sharded |
| **Qwen-Image** | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` | Dual-stream + preprocessing | auto_pad, separate RoPE |
| **Wan2.2** | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` | Dual-Transformer + RoPE | Video transformer |
| **Z-Image** | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` | Unified sequence | Concatenated input |
@@ -453,7 +483,7 @@ Complete examples in the codebase:
Adding Sequence Parallel support to a transformer:

1. ✅ **Choose approach** - Use `_sp_plan` for standard cases, intrusive modification for complex cases
2. ✅ **Identify sharding boundaries** - Where should tensors be split/gathered?
2. ✅ **Identify sharding boundaries** - Where should tensors be split/gathered? And do any module boundaries need to move to make this possible?
3. ✅ **Extract inline operations** - Move `torch.cat`, `pad_sequence`, etc. to submodules (see the sketch after this list)
4. ✅ **Define `_sp_plan`** - Declare shard/gather points as class attribute
5. ✅ **Use `auto_pad` for variable lengths** - Support non-uniform sequences
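
To make item 3 concrete, here is a hypothetical sketch of pulling an inline `torch.cat` into a submodule so it sits at an `nn.Module` boundary where a plan entry could hook it. The names (`StreamConcat`, `stream_concat`, `MyTransformer`) are illustrative, not taken from the codebase.

```python
import torch
from torch import nn


class StreamConcat(nn.Module):
    """Concatenate text and image tokens along the sequence dimension.

    Extracting this from forward() creates a module boundary (e.g. a
    hypothetical "stream_concat" entry in `_sp_plan`) where the concatenated
    sequence can be sharded or gathered.
    """

    def forward(self, text_tokens: torch.Tensor, image_tokens: torch.Tensor) -> torch.Tensor:
        return torch.cat([text_tokens, image_tokens], dim=1)


class MyTransformer(nn.Module):
    """Host module, shown only to illustrate where the submodule is registered."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.stream_concat = StreamConcat()  # was: torch.cat(...) inline in forward()
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, text_tokens: torch.Tensor, image_tokens: torch.Tensor) -> torch.Tensor:
        hidden_states = self.stream_concat(text_tokens, image_tokens)
        return self.proj_out(hidden_states)
```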