diff --git a/docs/design/feature/sequence_parallel.md b/docs/design/feature/sequence_parallel.md
index 18477e61a2..3b899ffe34 100644
--- a/docs/design/feature/sequence_parallel.md
+++ b/docs/design/feature/sequence_parallel.md
@@ -55,6 +55,8 @@ The `_sp_plan` mechanism allows SP **without modifying `forward()` logic**. The
 - Tensor operations happen at `nn.Module` boundaries
 - Predictable sharding/gathering patterns
 
+This is the ideal approach for integrating sequence parallelism into new models, as it is easier to maintain and keeps the model compatible with other types of acceleration.
+
 **How it works:**
 1. Declare `_sp_plan` dict in your transformer class
 2. Framework automatically applies hooks when `sequence_parallel_size > 1`
@@ -201,6 +203,36 @@ class TransformerWithRoPE(nn.Module):
     }
 ```
 
+**Pattern 3: Shard RoPE for Dual-Stream Attention**
+In some cases, different attention streams need to handle sequence parallelism differently. For example, we may want to shard the image embeddings while replicating the text embeddings so that joint attention is configured correctly.
+
+```python
+class DualStreamTransformer(nn.Module):
+    """
+    Dual-stream model where we need to replicate the text components, but shard
+    the image components to correctly handle sequence parallelism.
+    """
+    _sp_plan = {
+        # In this case, the rope_preparer returns a tuple of length 4, where the
+        # first 2 items correspond to the text inputs and the last 2 correspond to
+        # the visual inputs, so we only shard the latter.
+        "rope_preparer": {
+            # Outputs 0, 1 (text) - NOT sharded (replicated)
+            # Outputs 2, 3 (image) - sharded
+            2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),  # img_cos
+            3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),  # img_sin
+        },
+        # Shard transformer block INPUT
+        "transformer_blocks.0": {
+            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
+        },
+        # Gather at output
+        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+    }
+```
+
+NOTE: test carefully when refactoring classes that use this style of plan; the plan refers to outputs by position, so changing the order of the return values will break sequence parallelism.
+
 ### API Reference
 
 **SequenceParallelInput Parameters:**
@@ -240,7 +272,7 @@ class TransformerWithRoPE(nn.Module):
 
 ## Approach 2: Intrusive Modification (For Complex Cases)
 
-For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls. Importantly, when taking this approach, be careful to ensure that you correctly manage the `_sp_shard_depth`; if the sequence parallel shard depth is 0, Ulysses will not be used.
+For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls.
 
 **When to use:**
 
@@ -253,18 +285,15 @@
 from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather
 
 def forward(self, hidden_states, ...):
     if self.parallel_config.sequence_parallel_size > 1:
-        # hidden_states = sp_shard(hidden_states, dim=1)
+        hidden_states = sp_shard(hidden_states, dim=1)
 
     # ... computation ...
 
     if self.parallel_config.sequence_parallel_size > 1:
         output = sp_gather(output, dim=1)
-    # return output
+    return output
 ```
 
-Note that currently, `sp_shard` / `sp_gather` do *not* automatically manage the `_sp_shard_depth`; you need to be careful to manage it yourself.
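+
+For illustration, here is a minimal sketch of the intrusive approach for a dual-stream model, where the image tokens are sharded but the text tokens stay replicated for joint attention. It mirrors the snippet above and assumes the module exposes `parallel_config`, `transformer_blocks`, and `proj_out` attributes as in the LongCat transformer; adapt the names to your model.
+
+```python
+from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather
+
+def forward(self, hidden_states, encoder_hidden_states, ...):
+    sp_enabled = self.parallel_config.sequence_parallel_size > 1
+    if sp_enabled:
+        # Shard only the image stream; the text stream stays replicated
+        # so joint attention sees the full prompt on every rank.
+        hidden_states = sp_shard(hidden_states, dim=1)
+
+    for block in self.transformer_blocks:
+        encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, ...)
+
+    output = self.proj_out(hidden_states)
+    if sp_enabled:
+        # Re-assemble the full image sequence across SP ranks.
+        output = sp_gather(output, dim=1)
+    return output
+```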
 
 ---
 
@@ -439,6 +468,7 @@ Complete examples in the codebase:
 
 | Model | Path | Pattern | Notes |
 |-------|------|---------|-------|
+| **LongCat** | `vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py` | Dual-stream | Text components replicated, image components sharded |
 | **Qwen-Image** | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` | Dual-stream + preprocessing | auto_pad, separate RoPE |
 | **Wan2.2** | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` | Dual-Transformer + RoPE | Video transformer |
 | **Z-Image** | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` | Unified sequence | Concatenated input |
@@ -453,7 +483,7 @@ Complete examples in the codebase:
 Adding Sequence Parallel support to a transformer:
 
 1. ✅ **Choose approach** - Use `_sp_plan` for standard cases, intrusive modification for complex cases
-2. ✅ **Identify sharding boundaries** - Where should tensors be split/gathered?
+2. ✅ **Identify sharding boundaries** - Where should tensors be split/gathered, and which operations need to move to module boundaries to make this possible?
 3. ✅ **Extract inline operations** - Move `torch.cat`, `pad_sequence`, etc. to submodules
 4. ✅ **Define `_sp_plan`** - Declare shard/gather points as class attribute
 5. ✅ **Use `auto_pad` for variable lengths** - Support non-uniform sequences
diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py
index ccf35ea4ee..8d8e523d60 100644
--- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py
+++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py
@@ -18,11 +18,10 @@
 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
-from vllm_omni.diffusion.data import OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.parallel_state import (
-    get_sequence_parallel_rank,
-    get_sequence_parallel_world_size,
-    get_sp_group,
+from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.sp_plan import (
+    SequenceParallelInput,
+    SequenceParallelOutput,
 )
 from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.platforms import current_omni_platform
 
@@ -50,6 +49,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class LongCatImageAttention(nn.Module):
     def __init__(
         self,
+        parallel_config: DiffusionParallelConfig,
         query_dim: int,
         heads: int = 8,
         dim_head: int = 64,
@@ -64,7 +64,7 @@ def __init__(
         pre_only: bool = False,
     ):
         super().__init__()
-
+        self.parallel_config = parallel_config
         self.head_dim = dim_head
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.query_dim = query_dim
@@ -209,12 +209,8 @@ def forward(
             encoder_query = self.norm_added_q(encoder_query)
             encoder_key = self.norm_added_k(encoder_key)
 
-        # Check if SP is enabled from forward context (set by LongCatImageTransformer2DModel)
-        forward_ctx = get_forward_context()
-        sp_size = forward_ctx.sequence_parallel_size
-        use_sp_joint_attention = sp_size > 1 and not forward_ctx.split_text_embed_in_sp
-
-        if use_sp_joint_attention:
+        sp_size = self.parallel_config.sequence_parallel_size
+        if sp_size is not None and sp_size > 1:
             # SP Mode: Use common helper for RoPE + joint attention
             hidden_states = self._sp_attention_with_rope(
                 img_query=query,
@@ -248,12 +244,22 @@ def forward(
         # In SP mode, image part is chunked: (B, txt_len + img_len/SP, D)
         # Check if SP is enabled and we have text_seq_len info
-        forward_ctx = get_forward_context()
-        sp_size = forward_ctx.sequence_parallel_size
+        sp_size = self.parallel_config.sequence_parallel_size
         text_seq_len = kwargs.get("text_seq_len", None)
-        use_sp_single_stream = sp_size > 1 and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None
+        forward_ctx = get_forward_context()
+
+        if (
+            sp_size is not None
+            and sp_size > 1
+            and not forward_ctx.split_text_embed_in_sp
+            and text_seq_len is not None
+        ):
+            # Ensure that the SP split won't cause out of bounds issues.
+            if text_seq_len < 0 or text_seq_len > query.shape[1]:
+                raise ValueError(
+                    f"text_seq_len={text_seq_len} is out of bounds for sequence length {query.shape[1]}"
+                )
-
-        if use_sp_single_stream:
             # SP Mode for single-stream block:
             # Split QKV into text and image parts, then use common helper
             hidden_states = self._sp_attention_with_rope(
@@ -301,6 +307,7 @@ class LongCatImageTransformerBlock(nn.Module):
     def __init__(
         self,
+        parallel_config: DiffusionParallelConfig,
         dim: int,
         num_attention_heads: int,
         attention_head_dim: int,
@@ -309,10 +316,12 @@ def __init__(
     ):
         super().__init__()
 
+        self.parallel_config = parallel_config
         self.norm1 = AdaLayerNormZero(dim)
         self.norm1_context = AdaLayerNormZero(dim)
 
         self.attn = LongCatImageAttention(
+            parallel_config=parallel_config,
             query_dim=dim,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
@@ -429,6 +438,61 @@ def forward(self, timestep, hidden_dtype):
         return timesteps_emb
 
 
+class RoPEPreparer(nn.Module):
+    """
+    This module encapsulates RoPE computation to enable _sp_plan sharding
+    for text / image components.
+
+    For LongCat, which uses dual-stream attention, this means that text
+    components are replicated across SP ranks, while image components are
+    sharded.
+    """
+
+    def __init__(self, pos_embed: LongCatImagePosEmbed):
+        super().__init__()
+        self.pos_embed = pos_embed
+
+    def forward(
+        self,
+        txt_ids: torch.Tensor,
+        img_ids: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Compute RoPE embeddings for text and image sequences.
+
+        Args:
+            txt_ids: Text position IDs (txt_seq_len, n_axes)
+            img_ids: Image position IDs (img_seq_len, n_axes)
+
+        Returns:
+            Tuple of cosine / sine components for text & image
+            in the order: (txt_cos, txt_sin, img_cos, img_sin)
+
+        NOTE: be careful about the output order if this is refactored in the
+        future; we need to match the _sp_plan indices, since the text
+        components (0 & 1) need to be replicated across SP ranks,
+        while the image components (2 & 3) must be sharded.
+        """
+        # Concatenate and compute RoPE for full sequence
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+
+        if current_omni_platform.is_npu():
+            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+            freqs_cos = freqs_cos.npu()
+            freqs_sin = freqs_sin.npu()
+        else:
+            freqs_cos, freqs_sin = self.pos_embed(ids)
+
+        # Split into text and image portions
+        txt_len = txt_ids.shape[0]
+        txt_cos = freqs_cos[:txt_len]
+        txt_sin = freqs_sin[:txt_len]
+        img_cos = freqs_cos[txt_len:]
+        img_sin = freqs_sin[txt_len:]
+
+        return txt_cos, txt_sin, img_cos, img_sin
+
+
 class LongCatImageSingleTransformerBlock(nn.Module):
     """
     Single-stream Transformer block for LongCat with SP (Sequence Parallelism) support.
@@ -437,7 +501,14 @@ class LongCatImageSingleTransformerBlock(nn.Module):
     This keeps the block logic clean and centralizes SP logic in the attention layer.
""" - def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + def __init__( + self, + parallel_config: DiffusionParallelConfig, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + ): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) @@ -448,6 +519,7 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, # SP handling is delegated to LongCatImageAttention via text_seq_len kwarg self.attn = LongCatImageAttention( + parallel_config=parallel_config, query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, @@ -511,6 +583,23 @@ class LongCatImageTransformer2DModel(nn.Module): _repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"] + # Sequence Parallelism for LongCat (following diffusers' _cp_plan pattern) + _sp_plan = { + "": { + # Chunk the hidden states prior to the forward() + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + # Shard RoPE image embeddings after rope_preparer computes them + # Outputs 0, 1 are text components, so they aren't sharded + # Outputs 2, 3 are image components and are sharded + "rope_preparer": { + 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), + 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + # Gather at the last linear projection + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + def __init__( self, od_config: OmniDiffusionConfig, @@ -534,6 +623,7 @@ def __init__( self.parallel_config = od_config.parallel_config self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + self.rope_preparer = RoPEPreparer(self.pos_embed) self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim) @@ -543,6 +633,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ LongCatImageTransformerBlock( + parallel_config=self.parallel_config, dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -554,6 +645,7 @@ def __init__( self.single_transformer_blocks = nn.ModuleList( [ LongCatImageSingleTransformerBlock( + parallel_config=self.parallel_config, dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, @@ -581,40 +673,11 @@ def forward( return_dict: bool = True, ) -> torch.FloatTensor | Transformer2DModelOutput: fwd_context = get_forward_context() - # Before: hidden_states shape = (B, img_seq_len, in_channels) - # After: hidden_states shape = (B, img_seq_len // SP, in_channels) sp_size = self.parallel_config.sequence_parallel_size - # Store SP size in forward context for sub-modules to access if sp_size is not None and sp_size > 1: - sp_world_size = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() - original_shape = hidden_states.shape - hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] - # LongCat uses dual-stream (text + image) with joint attention - # Text embeddings should be replicated across SP ranks for correctness - fwd_context.sequence_parallel_size = sp_size fwd_context.split_text_embed_in_sp = False - # Mark SP as active so attention layers; we need this to ensure we use - # Ulysses instead of NoParallelAttention since we don't set an sp plan - # for this model. - # TODO: would be nice to refactor this to use sp_plan if possible to - # tracking this directly, even though we only have one level. 
-            fwd_context._sp_shard_depth = 1
-
-            # Debug log (only first forward)
-            if not hasattr(self, "_sp_forward_logged"):
-                self._sp_forward_logged = True
-                logger.info(
-                    f"[LongCat Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
-                    f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
-                )
-        else:
-            fwd_context._sp_shard_depth = 0
-            if not hasattr(self, "_sp_forward_logged"):
-                self._sp_forward_logged = True
-                logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}")
-
+        # Hidden states are already sharded by the _sp_plan hook before forward() when SP is active
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
@@ -622,46 +685,17 @@ def forward(
         temb = self.time_embed(timestep, hidden_states.dtype)
 
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
-        ids = torch.cat((txt_ids, img_ids), dim=0)
+        # Compute RoPE embeddings via the rope_preparer module.
+        # _sp_plan automatically shards img_cos / img_sin (outputs 2, 3);
+        # txt_cos / txt_sin (outputs 0, 1) remain replicated for dual-stream attention.
+        txt_cos, txt_sin, img_cos, img_sin = self.rope_preparer(txt_ids, img_ids)
 
-        if current_omni_platform.is_npu():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
-
-        # SP: Chunk RoPE embeddings along sequence dimension
-        if sp_size is not None and sp_size > 1:
-            sp_world_size = get_sequence_parallel_world_size()
-            sp_rank = get_sequence_parallel_rank()
-            freqs_cos, freqs_sin = image_rotary_emb
-            txt_len = txt_ids.shape[0]
-
-            # Split RoPE into text and image portions
-            # txt_freqs: (txt_seq_len, head_dim) - keep full for all ranks
-            # img_freqs: (img_seq_len, head_dim) -> (img_seq_len // SP, head_dim)
-            txt_freqs_cos = freqs_cos[:txt_len]
-            txt_freqs_sin = freqs_sin[:txt_len]
-            img_freqs_cos = freqs_cos[txt_len:]
-            img_freqs_sin = freqs_sin[txt_len:]
-
-            # Chunk image RoPE for each SP rank
-            # img_freqs_cos: (img_seq_len // SP, head_dim)
-            # img_freqs_sin: (img_seq_len // SP, head_dim)
-            img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank]
-            img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank]
-
-            # Optionally chunk text RoPE if split_text_embed_in_sp is True
-            if get_forward_context().split_text_embed_in_sp:
-                txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank]
-                txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank]
-
-            # Reconstruct image_rotary_emb with chunked values
-            # Final shape: (txt_seq_len + img_seq_len // SP, head_dim)
-            image_rotary_emb = (
-                torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
-                torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
-            )
+        # Reconstruct image_rotary_emb with the (possibly sharded) values
+        # Final shape: (txt_seq_len + img_seq_len // SP, head_dim)
+        image_rotary_emb = (
+            torch.cat([txt_cos, img_cos], dim=0),
+            torch.cat([txt_sin, img_sin], dim=0),
+        )
 
         for block in self.transformer_blocks:
             encoder_hidden_states, hidden_states = block(
@@ -680,13 +714,9 @@ def forward(
             )
 
         hidden_states = self.norm_out(hidden_states, temb)
 
-        output = self.proj_out(hidden_states)
-        # SP: All-gather output to reconstruct full sequence
-        if sp_size is not None and sp_size > 1:
-            output = get_sp_group().all_gather(output, dim=1)
-            # Mark SP as inactive after gathering
-            get_forward_context()._sp_shard_depth = 0
+        # The _sp_plan hook on proj_out gathers the full sequence across SP ranks
+        output = self.proj_out(hidden_states)
 
         if not return_dict:
             return (output,)