
[Refactor] Use SP Plan for LongCat Sequence Parallelism#1772

Merged
wtomin merged 8 commits into vllm-project:main from alex-jw-brooks:longcat_sp
Mar 13, 2026

Conversation

@alex-jw-brooks (Contributor) commented Mar 10, 2026

Purpose

Fix #1692
Refactors LongCat Image to use the SP Plan Approach

It also looks like LongCat Image is broken in the non-SP case due to a missing attribute on forward_ctx, probably introduced by refactoring in the PR that fixed SP 😞

This PR

  • Fixes the no SP case
  • Migrates LongCat Image to use SP Plan
  • Removes some of the hacks around storing extra SP info on the context; instead just pass the parallel config
  • Removes the Intrusive Modification docs for sequence parallelism, to push contributors towards a unified pattern moving forward
  • Adds an example for handling dual stream attention in the sequence parallel docs

CC @wtomin @ZJY0516, could you PTAL?

I'll also wait for this to be merged before finishing #1487, so that the SP details won't bleed into the TeaCache extractor for this model.

Signed-off-by: Alex Brooks <albrooks@redhat.com>

don't pass sp on fwd context

Signed-off-by: Alex Brooks <albrooks@redhat.com>

formatting

Signed-off-by: Alex Brooks <albrooks@redhat.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 87d6cbacfc


Comment thread vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py Outdated
Signed-off-by: Alex Brooks <albrooks@redhat.com>
@wtomin (Collaborator) commented Mar 10, 2026

This SP implementation looks good to me. Please enrich your PR body with some experiment evidence:

  • the inference script that runs ulysses-sp and ring-attention;
  • the speed/memory performance of LongCat Image SP inference before and after this PR, just to verify that the sp performance is intact.


## Approach 2: Intrusive Modification (For Complex Cases)

For models with dynamic sharding logic that cannot be expressed via `_sp_plan`, manually insert shard/gather calls. Importantly, when taking this approach, be careful to ensure that you correctly manage the `_sp_shard_depth`; if the sequence parallel shard depth is 0, Ulysses will not be used.
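
To make the contrast with the plan-based approach concrete, here is a minimal sketch of the intrusive style, with "sharding" simulated on plain Python lists and a fixed world size (the helper names are assumptions, not the actual vLLM-Omni API):

```python
WORLD_SIZE = 2  # simulated number of sequence-parallel ranks

def sp_shard(seq, rank):
    """Each rank keeps one contiguous slice of the sequence dimension."""
    chunk = len(seq) // WORLD_SIZE
    return seq[rank * chunk:(rank + 1) * chunk]

def sp_gather(shards):
    """All-gather: reassemble the full sequence from the per-rank shards."""
    return [tok for shard in shards for tok in shard]

def forward(seq):
    # Intrusive SP: shard on entry, run the blocks per rank, gather on exit.
    shards = [sp_shard(seq, r) for r in range(WORLD_SIZE)]
    shards = [[tok * 2 for tok in shard] for shard in shards]  # stand-in "blocks"
    return sp_gather(shards)

print(forward([1, 2, 3, 4]))  # → [2, 4, 6, 8]
```

In the real model each rank runs only its own shard and the gather is a collective; the point is that the shard/gather calls live inside forward itself, which is exactly what the plan-based approach avoids.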
Collaborator

I suggest that we still keep the document related to intrusive modification, in case some models have very complicated sharding logic in the future.

Remove `_sp_shard_depth` from the document, as it is not an interface meant to be exposed to developers.
Since #1704 is merged, there is no bug with either `_sp_plan` or a manual SP implementation.

@alex-jw-brooks (Contributor, Author) Mar 11, 2026

@wtomin @lishunyang12 I see - I added it back, but my main concern is contributors copy+pasting this approach across models where it isn't needed, and it becoming an extra maintenance burden with hooks that take a similar approach to TeaCache.

Maybe this is more of an antipattern on the way TeaCache is implemented, but given that we already have to copy so much forward code into the extractors, I think it would be cleaner & a similar amount of work to have a SupportsTeaCache protocol, and just refactor the model .forward boundaries to support it.

Any thoughts on this? I will open a PR so that we can discuss this more concretely, but the combination of SP potentially modifying forward with forward being mostly copied into extractors may be a frequent cause of bugs when they are used together.

Collaborator

Indeed, we can later discuss a better interface for TeaCache that is both cleaner and easier to maintain.

@lishunyang12 (Collaborator) left a comment

Left a couple of nits. I agree with wtomin's suggestion to keep the intrusive modification docs.

Comment thread docs/design/feature/sequence_parallel.md Outdated

_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]

# Sequence Parallelism for LongCat (following diffusers' _cp_plan pattern)
Collaborator

Nit: should this say _sp_plan? _cp_plan is the diffusers name for context parallelism.

Contributor Author

No, this is correct - it's saying that the pattern for sequence parallelism in vLLM Omni is similar to the pattern Diffusers takes for context parallelism.

Signed-off-by: Alex Brooks <albrooks@redhat.com>
@alex-jw-brooks (Contributor, Author) commented Mar 11, 2026

Thanks for the reviews @wtomin @lishunyang12 - for the inference script, you can just run the text_to_image example, e.g.,

python text_to_image.py --model meituan-longcat/LongCat-Image --ulysses-degree $ulysses --ring-degree $ring --output out_ul${ulysses}_ring${ring}.png

I ran it a few times to compare against main, and the results are similar (metrics from running on an H100).

Main:

  • (ring=1, ulysses=1)
    • Crashes on main due to the forward context attr issue (fixed by this PR)
  • (ring=1, ulysses=2)
    • Total generation time: 2.9256 seconds (2925.56 ms)
  • (ring=2, ulysses=1)
    • Total generation time: 3.2385 seconds (3238.55 ms)
  • (ring=2, ulysses=2)
    • Total generation time: 4.2973 seconds (4297.34 ms)

This branch:

  • (ring=1, ulysses=1)
    • Total generation time: 3.9868 seconds (3986.84 ms)
  • (ring=1, ulysses=2)
    • Total generation time: 2.9242 seconds (2924.18 ms)
  • (ring=2, ulysses=1)
    • Total generation time: 3.2686 seconds (3268.56 ms)
  • (ring=2, ulysses=2)
    • Total generation time: 4.2638 seconds (4263.80 ms)

For memory, both branches peak at around ~35000 MiB on each GPU used for parallelism.

Comment thread docs/design/feature/sequence_parallel.md Outdated (3 threads)
alex-jw-brooks and others added 2 commits March 11, 2026 08:44
Co-authored-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Signed-off-by: Alex Brooks <albrooks@redhat.com>
@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 12, 2026
@wtomin wtomin added the ready label to trigger buildkite CI label Mar 12, 2026
@lishunyang12 (Collaborator) left a comment

Looks good overall — clean migration to _sp_plan. Left a couple comments.

# LongCat uses dual-stream (text + image) with joint attention
# Text embeddings should be replicated across SP ranks for correctness
fwd_context.sequence_parallel_size = sp_size
fwd_context.split_text_embed_in_sp = False
Collaborator

This is only set when sp_size > 1, but single-stream attention reads it from forward_ctx unconditionally. The non-SP path relies on the ForwardContext default being False — works now but fragile. Consider always setting it explicitly.
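
A minimal sketch of this suggestion, assuming a helper that configures the context in one place (fwd_context here is a plain stand-in object, not the real ForwardContext; the attribute names come from the quoted snippet):

```python
from types import SimpleNamespace

def configure_sp(fwd_context, sp_size):
    # Set both attributes on every path, so the non-SP case never relies
    # on the ForwardContext class default.
    if sp_size is not None and sp_size > 1:
        fwd_context.sequence_parallel_size = sp_size
    else:
        fwd_context.sequence_parallel_size = 1
    fwd_context.split_text_embed_in_sp = False  # text embeds replicated across ranks
    return fwd_context

ctx = configure_sp(SimpleNamespace(), None)  # non-SP path still fully initialized
```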

@alex-jw-brooks (Contributor, Author) Mar 12, 2026

In the attention, it also checks that sp_size isn't None and is > 1, so this is fine since this is a declared attr on fwd context - the other issues around the forward context arose because this model was setting things on the forward context that weren't attributes 🙂

I can open a small follow-up around this, but let's do it separately from this PR since it matches what the current code does - I think setting it as a flag in forward is an antipattern anyway, since it should be consistent behavior within each model, and AFAIK this flag is never True in Omni at the moment anyway.

and text_seq_len is not None
):
# Ensure that the SP split won't cause out of bounds issues.
if text_seq_len < 0 or text_seq_len > query.shape[1]:
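
A hedged sketch of the stricter check discussed in this thread (the helper is hypothetical, not the PR's actual code; names follow the quoted snippet): requiring strictly fewer text tokens than the total also rules out the text_seq_len == query.shape[1] case, where SP would split an empty image tensor.

```python
def validate_text_seq_len(text_seq_len, total_seq_len):
    """Reject splits that would leave zero image tokens for SP to shard.

    Hypothetical helper for illustration: a strict upper bound guarantees
    at least one image token remains after the text prefix.
    """
    if not 0 <= text_seq_len < total_seq_len:
        raise ValueError(
            f"text_seq_len={text_seq_len} must leave at least one image token "
            f"in a sequence of length {total_seq_len}"
        )

validate_text_seq_len(3, 4)  # ok: one image token remains
```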
Collaborator

This allows text_seq_len == query.shape[1] (zero image tokens). Is that intentional? SP would split an empty tensor in that case.

@alex-jw-brooks (Contributor, Author) Mar 12, 2026

I had wanted to avoid sequence-parallel-specific validation that might be confusing, but good point - I think this will hit a bug in the existing code: the way it is currently written, it'll throw when applying the rotary embeddings from diffusers before it actually gets to the SP part. I will fix it 🙂

self.parallel_config = od_config.parallel_config

self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope)
self.rope_preparer = RoPEPreparer(self.pos_embed)
Collaborator

Nit: pos_embed is now accessible via both self.pos_embed and self.rope_preparer.pos_embed. Any risk of duplicate keys in state_dict?

Contributor Author

No, since there are no learnable params at the moment. Good point though; this is consistent with what current models are doing, so I think it would be best to refactor them all together if we change this.

@wtomin (Collaborator) left a comment

LGTM. I think it can be merged.

@wtomin wtomin merged commit c8cf8a7 into vllm-project:main Mar 13, 2026
6 of 7 checks passed
yiliu30 pushed a commit to yiliu30/vllm-omni-fork that referenced this pull request Mar 20, 2026
…#1772)

Signed-off-by: Alex Brooks <albrooks@redhat.com>
Co-authored-by: Didan Deng <33117903+wtomin@users.noreply.github.com>

Signed-off-by: yiliu30 <yi4.liu@intel.com>

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Refactor]: Add SP Plan for LongCat Image

4 participants