Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/design/feature/sequence_parallel.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ class TransformerWithRoPE(nn.Module):

## Approach 2: Intrusive Modification (For Complex Cases)

For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls.
For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls. Importantly, when taking this approach you must manage `_sp_shard_depth` on the forward context yourself; if the sequence parallel shard depth is 0, Ulysses will not be used.


**When to use:**
- Dynamic/conditional sharding logic
Expand All @@ -252,15 +253,18 @@ from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather

def forward(self, hidden_states, ...):
if self.parallel_config.sequence_parallel_size > 1:
# <Increment the _sp_shard_depth on the fwd context>
hidden_states = sp_shard(hidden_states, dim=1)

# ... computation ...

if self.parallel_config.sequence_parallel_size > 1:
output = sp_gather(output, dim=1)
# <Decrement the _sp_shard_depth on the fwd context>

return output
```
Note that currently, `sp_shard` / `sp_gather` do *not* automatically manage the `_sp_shard_depth`; you need to be careful to manage it yourself.

---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,19 +583,28 @@ def forward(
guidance: torch.Tensor = None,
return_dict: bool = True,
) -> torch.FloatTensor | Transformer2DModelOutput:
fwd_context = get_forward_context()
# Before: hidden_states shape = (B, img_seq_len, in_channels)
# After: hidden_states shape = (B, img_seq_len // SP, in_channels)
sp_size = self.parallel_config.sequence_parallel_size
# Store SP size in forward context for sub-modules to access
get_forward_context().sequence_parallel_size = sp_size
if sp_size > 1:
if sp_size is not None and sp_size > 1:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep sequence_parallel_size in context for all forwards

LongCatImageAttention.forward reads get_forward_context().sequence_parallel_size unconditionally (both joint and single-stream paths), but this change only sets fwd_context.sequence_parallel_size inside the sp_size > 1 branch. When LongCat runs with sequence_parallel_size of 1 (or None), the context no longer has that attribute and attention will raise AttributeError before inference completes. Please set the field regardless of SP mode so non-SP runs continue to work.

Useful? React with 👍 / 👎.

sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
original_shape = hidden_states.shape
hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
# LongCat uses dual-stream (text + image) with joint attention
# Text embeddings should be replicated across SP ranks for correctness
get_forward_context().split_text_embed_in_sp = False
fwd_context.sequence_parallel_size = sp_size
fwd_context.split_text_embed_in_sp = False

# Mark SP as active for the attention layers; we need this to ensure we use
# Ulysses instead of NoParallelAttention since we don't set an sp plan
# for this model.
# TODO: would be nice to refactor this to use sp_plan if possible to
# track this directly, even though we only have one level.
fwd_context._sp_shard_depth = 1

# Debug log (only first forward)
if not hasattr(self, "_sp_forward_logged"):
self._sp_forward_logged = True
Expand All @@ -604,6 +613,7 @@ def forward(
f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
)
else:
fwd_context._sp_shard_depth = 0
if not hasattr(self, "_sp_forward_logged"):
self._sp_forward_logged = True
logger.info(f"[LongCat Transformer] SP disabled: sp_size={sp_size}")
Expand All @@ -624,7 +634,7 @@ def forward(
image_rotary_emb = self.pos_embed(ids)

# SP: Chunk RoPE embeddings along sequence dimension
if self.parallel_config.sequence_parallel_size > 1:
if sp_size is not None and sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_cos, freqs_sin = image_rotary_emb
Expand Down Expand Up @@ -656,46 +666,30 @@ def forward(
torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
)

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
for block in self.single_transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

# SP: All-gather output to reconstruct full sequence
if self.parallel_config.sequence_parallel_size > 1:
if sp_size is not None and sp_size > 1:
output = get_sp_group().all_gather(output, dim=1)
# Mark SP as inactive after gathering
get_forward_context()._sp_shard_depth = 0

if not return_dict:
return (output,)
Expand Down