
[Bugfix]: SP attention not enabling when _sp_plan hooks are not applied#1704

Merged
wtomin merged 11 commits into vllm-project:main from wtomin:sp-test-re
Mar 11, 2026

Conversation

@wtomin
Collaborator

@wtomin wtomin commented Mar 6, 2026

Purpose

This PR fixes one bug: SP attention is not enabled when _sp_plan hooks are not applied. The bug occurs in two cases:

  1. Some models' SP implementations do not use _sp_plan, e.g., LongCatImage ([Bug]: LongCat Image Sequence Parallelism is Broken #1556);
  2. The standalone SP unit test script does not use _sp_plan.

Although the first case already has a quick fix merged (#1631), fwd_context._sp_shard_depth is not intended to be exposed to developers, who can easily forget to set it manually.

Therefore, this PR proposes checking _sp_shard_depth only when _sp_plan hooks are applied. When they are not applied, sp_active is determined solely by sp_size in the configuration. This benefits both manual SP implementations and the standalone SP unit test.
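The proposed check can be sketched as follows. The names sp_active, _sp_shard_depth, sp_plan_hooks_applied, and sequence_parallel_size come from this PR's discussion; the surrounding class is an illustrative simplification, not the actual vllm-omni forward context:

```python
from dataclasses import dataclass


@dataclass
class ForwardContext:
    """Simplified stand-in for the diffusion forward context (illustrative only)."""

    sequence_parallel_size: int = 1      # sp_size from the configuration
    sp_plan_hooks_applied: bool = False  # set when _sp_plan hooks are attached
    _sp_shard_depth: int = 0             # incremented/decremented by _sp_plan hooks

    @property
    def sp_active(self) -> bool:
        if self.sp_plan_hooks_applied:
            # Original behavior: SP is only active inside a sharded region.
            return self._sp_shard_depth > 0
        # Manual SP or standalone tests: decide from the configured SP size.
        return self.sequence_parallel_size > 1
```

With this check, a manual SP implementation that never touches _sp_shard_depth still reports sp_active as long as the configured SP size is greater than 1, while the original hook-driven behavior is preserved.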

This PR also includes minor edits to the SP unit test.

Test Plan

  • Standalone SP UT

pytest -s -v tests/diffusion/attention/test_attention_sp.py

  • LongCatImage SP
cd examples/offline_inference/text_to_image
python text_to_image.py --model meituan-longcat/LongCat-Image --ulysses-degree 2

python text_to_image.py --model meituan-longcat/LongCat-Image --ulysses-degree 2 --ring-degree 2

Test Result

  • Standalone SP UT

[baseline (no SP)] ✓ Saved output with shape torch.Size([2, 16, 64]):
  - batch_size=2, seq_len=16
  - num_heads=8, head_size=8
  - dtype=torch.bfloat16, causal=False, use_sync=False

[SP (ulysses=2, ring=2)] ✓ Saved output with shape torch.Size([2, 16, 64]):
  - batch_size=2, seq_len=16
  - num_heads=8, head_size=8
  - dtype=torch.bfloat16, causal=False, use_sync=False

================================================================================
Comparing outputs between baseline and SP...
  Baseline output shape: torch.Size([2, 16, 64])
  SP output shape: torch.Size([2, 16, 64])

================================================================================
Output Difference Analysis:
  - Max absolute difference: 1.562500e-02
  - Mean absolute difference: 4.872084e-04
  - Max relative difference: 9.999897e-01
  - Mean relative difference: 2.934728e-03
  - Baseline output range: [-3.140625e+00, 3.359375e+00]
  - SP output range: [-3.140625e+00, 3.359375e+00]
================================================================================

✓ Test passed: SP output matches baseline within tolerance
======================================================================= 1 passed, 20 warnings in 64.19s (0:01:04) ========================================================================
  • LongCatImage SP
ulysses-degree  ring-degree  generation time  image
2               1            2.99s            qwen_image_output
2               2            3.68s            qwen_image_output

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts and test commands. Please state the reasons if your code doesn't require additional test scripts. For test file guidelines, please check the test style doc.
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.


@wtomin
Collaborator Author

wtomin commented Mar 6, 2026

I raised this PR to fix a bug present in the SP unit test (#1705).

#1692 tackles the LongCat Image SP problem from a different angle, by using _sp_plan instead of a manual SP implementation.

Thus the two solutions are not mutually exclusive. @alex-jw-brooks I still suggest supporting LongCat Image SP with _sp_plan.

@wtomin
Collaborator Author

wtomin commented Mar 6, 2026

@ZJY0516 @gcanlin @hsliuustc0106 @SamitHuang Please give your comments. Thanks.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4ecde3342b


@hsliuustc0106
Collaborator

any perf comparison with sgl-d?

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


Review

Rating: 8.5/10 | Verdict: ✅ Approved

Summary

Correct bugfix enabling SP attention when _sp_plan hooks are not applied (manual SP, standalone tests). Root cause identified and fix is minimal and targeted.

CI Gate Checks (Step 0)

  • ✅ DCO: SUCCESS
  • ✅ Pre-commit: SUCCESS
  • ✅ Mergeable: MERGEABLE

Root Cause Analysis

Problem: sp_active property only checked _sp_shard_depth > 0, which is only meaningful within the _sp_plan hook mechanism. When hooks are not applied (manual SP, standalone tests), _sp_shard_depth stays at 0, causing SP attention to be incorrectly disabled.

Fix: Add sp_plan_hooks_applied flag to distinguish:

  • Hooks applied: use _sp_shard_depth (original behavior)
  • Hooks NOT applied: default to True when sequence_parallel_size > 1

Correctness Analysis

Scenario                 Before                   After                     Status
_sp_plan hooks applied   ✅ Use _sp_shard_depth   ✅ Same                   Preserved
Manual SP (no hooks)     ❌ Always disabled       ✅ Enabled when SP > 1    Fixed
Standalone tests         ❌ Always disabled       ✅ Enabled when SP > 1    Fixed

Highlights

  • ✅ Minimal change (3 files, focused on root cause)
  • ✅ Clear flag (sp_plan_hooks_applied) for state tracking
  • ✅ Backward compatible (preserves hook behavior)
  • ✅ Existing test modified to verify fix
  • ✅ Error handling for missing omni_diffusion_config

Test Changes

Removed: attn_backend parameter (simplification)
Added: seed_everything() helper function
Modified: Test now works without _sp_plan hooks
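The seed_everything() helper named above is not shown on this page; the following is a minimal sketch of what such a helper typically does. The torch and NumPy seeding calls are standard APIs, but the actual implementation in the PR is an assumption:

```python
import os
import random


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy, and torch RNGs for reproducible tests (illustrative sketch)."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; skip
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass  # torch not installed; skip
```

Calling it at the start of each test process makes the baseline and SP runs start from identical random weights and inputs, which is what the output comparison above relies on.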

Minor Suggestions (non-blocking)

  1. Test coverage: The existing test is modified but no new test explicitly validates the "no hooks" scenario. Consider adding a comment in the test explaining it now exercises the new code path (hooks NOT applied).

  2. Error message: Lines 60-61 raise a ValueError when omni_diffusion_config is None. Consider adding context: "omni_diffusion_config is not set when checking sp_active! Please call ..."

  3. Flag initialization: sp_plan_hooks_applied defaults to False. Consider adding a class-level comment explaining the flag's purpose and when it's set.
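The error-message suggestion (item 2 above) could look like the following. The helper name and exact wording are hypothetical; only the omni_diffusion_config / sp_active names come from the review:

```python
def require_omni_diffusion_config(omni_diffusion_config):
    """Hypothetical guard illustrating the reviewer's suggested error message."""
    if omni_diffusion_config is None:
        raise ValueError(
            "omni_diffusion_config is not set when checking sp_active! "
            "Please set the forward context before querying sp_active."
        )
    return omni_diffusion_config
```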

Pitfalls Check

Directory                     Pitfall           Status
diffusion/forward_context.py  State management  ✅ New flag
diffusion/registry.py         Flag setting      ✅ Correct
tests/                        Regression test   ✅ Modified

Recommendation

Ready to merge. Clean bugfix with good test coverage.


Reviewed by OpenClaw with vllm-omni-skills 🦐

Skill: vllm-omni-review (Bugfix)

@pytest.mark.parametrize("head_size", [8])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
Contributor


Is there a reason for removing fp16 here?

Collaborator Author


Due to #906, the default attention backend FA does not support fp16.

@wtomin wtomin added the ready label to trigger buildkite CI label Mar 9, 2026
Contributor

@alex-jw-brooks alex-jw-brooks left a comment


looks good to me, thanks!

wtomin and others added 11 commits March 10, 2026 13:09
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Co-authored-by: Canlin Guo <961750412@qq.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
Signed-off-by: Didan Deng <33117903+wtomin@users.noreply.github.com>
@wtomin wtomin merged commit 7543f2f into vllm-project:main Mar 11, 2026
7 checks passed

Labels

ready label to trigger buildkite CI


Development

Successfully merging this pull request may close these issues.

[Bug]: Test test_attention_sp.py failed: TypeError: 'NoneType' object is not callable when calling current_omni_platform.seed_everything
