Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions docs/design/feature/sequence_parallel.md
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,64 @@ _sp_plan = {

**Problem:** `SequenceParallelInput(auto_pad=False)` - auto_pad should be True to enable automatic sequence padding.

**Solution:** In `SequenceParallelInput`, set `auto_pad=True`:
**Solution:** In `SequenceParallelInput`, set `auto_pad=True` and add attention mask support.

> **Experimental Feature:** `auto_pad=True` is an experimental feature and may be changed in the future. We plan to improve this solution to involve minimal changes to model files. More details are [here](https://github.com/vllm-project/vllm-omni/issues/1324).

**Constraints of auto_pad:**

| Constraint | Description |
|------------|-------------|
| **Attention Backend Compatibility** | The attention backends must support `attention_mask`. Currently only `TORCH_SDPA` and `FLASH_ATTN` (default for diffusion models) are supported. |
| **Ring Attention Limitation** | Ring attention does not support `attention_mask`. Therefore, when using `auto_pad=True`, the combination of Ulysses + Ring attention is not feasible. |

1. Enable `auto_pad=True` for all sequence-dimension inputs in `_sp_plan`:
```python
"blocks.0": {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True)
_sp_plan = {
"rope": {
0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True, auto_pad=True),
1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True, auto_pad=True),
},
"blocks.0": {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True)
},
...
}
```

2. Create attention mask dynamically when padding is applied:
```python
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata

# In model forward(), before transformer blocks:
hidden_states_mask = None
ctx = get_forward_context()
if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
batch_size = hidden_states.shape[0]
padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
hidden_states_mask = torch.ones(batch_size, padded_seq_len, dtype=torch.bool, device=hidden_states.device)
hidden_states_mask[:, ctx.sp_original_seq_len:] = False

# Pass mask to attention layers
attn_metadata = AttentionMetadata(attn_mask=hidden_states_mask) if hidden_states_mask is not None else None
output = self.attn(query, key, value, attn_metadata)
```

**Important Quality Considerations:**

While `auto_pad` enables generation for irregular resolutions, be aware of potential quality impacts:

| Aspect | Impact |
|--------|--------|
| **Training Distribution** | Models perform best on aspect ratios seen during training (e.g., 1:1, 16:9, 4:3). Unusual ratios like 700x400 (1.75:1) may produce lower quality results. |
| **Padding Overhead** | Padded positions consume compute even when masked. For best efficiency, prefer resolutions divisible by `sp_size`. |

Comment thread
gcanlin marked this conversation as resolved.
**Recommendations for users:**
- Use standard aspect ratios when possible (e.g., 768x432 for 16:9 instead of 700x400)
- Ensure post-patch dimensions are divisible by `sp_size` for optimal quality
- Test generation quality when using unusual resolutions

### Issue: Inline operations not sharded

**Symptoms:** Some tensors remain full-sized, not sharded.
Expand Down
57 changes: 48 additions & 9 deletions vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.distributed.sp_plan import (
SequenceParallelInput,
SequenceParallelOutput,
)
from vllm_omni.diffusion.forward_context import get_forward_context

logger = init_logger(__name__)

Expand Down Expand Up @@ -395,6 +397,7 @@ def forward(
self,
hidden_states: torch.Tensor,
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
# Fused QKV projection
qkv, _ = self.to_qkv(hidden_states)
Expand All @@ -418,8 +421,13 @@ def forward(
query = apply_rotary_emb_wan(query, freqs_cos, freqs_sin)
key = apply_rotary_emb_wan(key, freqs_cos, freqs_sin)

# Create attention metadata if mask is provided
attn_metadata = None
if attn_mask is not None:
attn_metadata = AttentionMetadata(attn_mask=attn_mask)

# Compute attention using unified attention layer
hidden_states = self.attn(query, key, value)
hidden_states = self.attn(query, key, value, attn_metadata)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)

Expand Down Expand Up @@ -637,6 +645,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: tuple[torch.Tensor, torch.Tensor],
hidden_states_mask: torch.Tensor | None = None,
) -> torch.Tensor:
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
Expand All @@ -657,7 +666,7 @@ def forward(

# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(norm_hidden_states, rotary_emb)
attn_output = self.attn1(norm_hidden_states, rotary_emb, hidden_states_mask)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

# 2. Cross-attention
Expand Down Expand Up @@ -727,25 +736,32 @@ class WanTransformer3DModel(nn.Module):
# Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism)
_sp_plan = {
# Shard RoPE embeddings after rope module computes them
# auto_pad=True enables variable sequence length support
"rope": {
0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # freqs_cos [1, seq, 1, dim]
1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # freqs_sin [1, seq, 1, dim]
0: SequenceParallelInput(
split_dim=1, expected_dims=4, split_output=True, auto_pad=True
), # freqs_cos [1, seq, 1, dim]
1: SequenceParallelInput(
split_dim=1, expected_dims=4, split_output=True, auto_pad=True
), # freqs_sin [1, seq, 1, dim]
},
# Shard timestep_proj for TI2V models (4D tensor: [batch, seq_len, 6, inner_dim])
# This is only active when ts_seq_len is not None (TI2V mode)
# Output is a single tensor, shard along dim=1 (sequence dimension)
"timestep_proj_prepare": {
0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), # [B, seq, 6, dim]
0: SequenceParallelInput(
split_dim=1, expected_dims=4, split_output=True, auto_pad=True
), # [B, seq, 6, dim]
},
# Shard hidden_states at first transformer block input
# (after patch_embedding + flatten + transpose)
"blocks.0": {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), # [B, seq, dim]
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True), # [B, seq, dim]
Comment thread
gcanlin marked this conversation as resolved.
},
# Shard output scale/shift for TI2V (3D); T2V outputs 2D and skips sharding
"output_scale_shift_prepare": {
0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
1: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True),
1: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True),
},
# Gather at proj_out (final linear projection before unpatchify)
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
Expand Down Expand Up @@ -878,9 +894,32 @@ def forward(
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

# Check for SP auto_pad: create attention mask dynamically if padding was applied
hidden_states_mask = None # default
config = get_forward_context().omni_diffusion_config
parallel_config = config.parallel_config
if parallel_config is not None and parallel_config.sequence_parallel_size > 1:
ctx = get_forward_context()
if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
# Create mask for the full (padded) sequence
# valid positions = True, padding positions = False
batch_size = hidden_states.shape[0]
padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
hidden_states_mask = torch.ones(
batch_size,
padded_seq_len,
dtype=torch.bool,
device=hidden_states.device,
)
hidden_states_mask[:, ctx.sp_original_seq_len :] = False

Comment thread
gcanlin marked this conversation as resolved.
# if mask is all true, set it to None
if hidden_states_mask is not None and hidden_states_mask.all():
hidden_states_mask = None

# Transformer blocks
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, hidden_states_mask)

# Output norm, projection & unpatchify
shift, scale = self.output_scale_shift_prepare(temb)
Expand Down