From d7efce6f61416c1544789cba94e60ded2db56563 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 3 Mar 2026 07:05:59 +0000 Subject: [PATCH 1/2] update sp docs Signed-off-by: Alex Brooks --- docs/design/feature/sequence_parallel.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/design/feature/sequence_parallel.md b/docs/design/feature/sequence_parallel.md index ba09d6c046f..18477e61a2d 100644 --- a/docs/design/feature/sequence_parallel.md +++ b/docs/design/feature/sequence_parallel.md @@ -240,7 +240,8 @@ class TransformerWithRoPE(nn.Module): ## Approach 2: Intrusive Modification (For Complex Cases) -For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls. +For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls. Importantly, when taking this approach, be careful to ensure that you correctly manage the `_sp_shard_depth`; if the sequence parallel shard depth is 0, Ulysses will not be used. + **When to use:** - Dynamic/conditional sharding logic @@ -252,15 +253,18 @@ from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather def forward(self, hidden_states, ...): if self.parallel_config.sequence_parallel_size > 1: + # hidden_states = sp_shard(hidden_states, dim=1) # ... computation ... if self.parallel_config.sequence_parallel_size > 1: output = sp_gather(output, dim=1) + # return output ``` +Note that currently, `sp_shard` / `sp_gather` do *not* automatically manage the `_sp_shard_depth`; you need to be careful to manage it yourself. 
--- From 6bfef321ff7f45da663dcaca52d1b7d07544ec47 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 3 Mar 2026 07:06:51 +0000 Subject: [PATCH 2/2] fix parallel shard depth, cleanup, remove checkpointing Signed-off-by: Alex Brooks --- .../longcat_image_transformer.py | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index 329fc648bbc..833e7bb2c6a 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -583,19 +583,28 @@ def forward( guidance: torch.Tensor = None, return_dict: bool = True, ) -> torch.FloatTensor | Transformer2DModelOutput: + fwd_context = get_forward_context() # Before: hidden_states shape = (B, img_seq_len, in_channels) # After: hidden_states shape = (B, img_seq_len // SP, in_channels) sp_size = self.parallel_config.sequence_parallel_size # Store SP size in forward context for sub-modules to access - get_forward_context().sequence_parallel_size = sp_size - if sp_size > 1: + if sp_size is not None and sp_size > 1: sp_world_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() original_shape = hidden_states.shape hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] # LongCat uses dual-stream (text + image) with joint attention # Text embeddings should be replicated across SP ranks for correctness - get_forward_context().split_text_embed_in_sp = False + fwd_context.sequence_parallel_size = sp_size + fwd_context.split_text_embed_in_sp = False + + # Mark SP as active for the attention layers; we need this to ensure we use + # Ulysses instead of NoParallelAttention since we don't set an sp plan + # for this model. 
+ # TODO: would be nice to refactor this to use sp_plan if possible to + # track this directly, even though we only have one level. + fwd_context._sp_shard_depth = 1 + # Debug log (only first forward) if not hasattr(self, "_sp_forward_logged"): self._sp_forward_logged = True @@ -604,6 +613,7 @@ def forward( f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" ) else: + fwd_context._sp_shard_depth = 0 if not hasattr(self, "_sp_forward_logged"): self._sp_forward_logged = True logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}") @@ -624,7 +634,7 @@ def forward( image_rotary_emb = self.pos_embed(ids) # SP: Chunk RoPE embeddings along sequence dimension - if self.parallel_config.sequence_parallel_size > 1: + if sp_size is not None and sp_size > 1: sp_world_size = get_sequence_parallel_world_size() sp_rank = get_sequence_parallel_rank() freqs_cos, freqs_sin = image_rotary_emb @@ -656,46 +666,30 @@ def forward( torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), ) - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - ) - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]: - encoder_hidden_states, hidden_states = 
self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, - ) - else: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + for block in self.single_transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) # SP: All-gather output to reconstruct full sequence - if self.parallel_config.sequence_parallel_size > 1: + if sp_size is not None and sp_size > 1: output = get_sp_group().all_gather(output, dim=1) + # Mark SP as inactive after gathering + get_forward_context()._sp_shard_depth = 0 if not return_dict: return (output,)