Skip to content

[Bugfix] [diffusion] Fix cache-dit with sp-degree only#19965

Merged
BBuf merged 10 commits into
sgl-project:mainfrom
OrangeRedeng:sequence_parallelism_with_cache_fix
Mar 18, 2026
Merged

[Bugfix] [diffusion] Fix cache-dit with sp-degree only#19965
BBuf merged 10 commits into
sgl-project:mainfrom
OrangeRedeng:sequence_parallelism_with_cache_fix

Conversation

@OrangeRedeng
Copy link
Copy Markdown
Contributor

@OrangeRedeng OrangeRedeng commented Mar 5, 2026

Motivation

Fix bug from #19955 issue, now sp-degree only works correctly

Modifications

Change getattr(sp_group, "ulysses_world_size", None) to get_sp_group().ulysses_world_size, the same with ring_world_size, change ParallelismBackend.NATIVE_PYTORCH to ParallelismBackend.AUTO

Accuracy Tests

No influence

Benchmarking and Profiling

SP-degree only gives best performance (If there is enough memory)

sp-degree 4/tp-size 2:
SGLANG_CACHE_DIT_FN=2 SGLANG_CACHE_DIT_BN=1 SGLANG_CACHE_DIT_WARMUP=4 SGLANG_CACHE_DIT_RDT=0.4 SGLANG_CACHE_DIT_MC=4 SGLANG_CACHE_DIT_TAYLORSEER=true SGLANG_CACHE_DIT_TS_ORDER=2 SGLANG_CACHE_DIT_ENABLED=true sglang generate --model-path ./Wan2.2-T2V-A14B-Diffusers-w8a8/ --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --height 720 --width 1280 --tp-size 2 --sp-degree 4 --num-gpus 8 --num-frames 81 --num-inference-steps 40 --port 30088

image

sp-degree 8/tp-size 1 (sp-degree only scenario):
SGLANG_CACHE_DIT_FN=2 SGLANG_CACHE_DIT_BN=1 SGLANG_CACHE_DIT_WARMUP=4 SGLANG_CACHE_DIT_RDT=0.4 SGLANG_CACHE_DIT_MC=4 SGLANG_CACHE_DIT_TAYLORSEER=true SGLANG_CACHE_DIT_TS_ORDER=2 SGLANG_CACHE_DIT_ENABLED=true sglang generate --model-path ./Wan2.2-T2V-A14B-Diffusers-w8a8/ --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --height 720 --width 1280 --tp-size 1 --sp-degree 8 --num-gpus 8 --num-frames 81 --num-inference-steps 40 --port 30088

image

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions Bot added the diffusion SGLang Diffusion label Mar 5, 2026
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug affecting the sp-degree only functionality within the diffusion cache-dit integration. The changes ensure that spatial parallelism configurations are correctly recognized and applied, improving the stability and performance of distributed model inference, particularly when using specific sp-degree setups. The update also refines how parallelism group sizes are accessed and defaults to an automatic backend selection for better compatibility.

Highlights

  • Bug Fix: Resolved an issue where sp-degree only configurations were not functioning correctly, as reported in issue [Bug] [Diffusion] Failed Wan2.2 with cache-dit with sp-degree only #19955.
  • Parallelism Configuration Update: Modified the retrieval of ulysses_world_size and ring_world_size to use get_sp_group().ulysses_world_size and get_sp_group().ring_world_size respectively, ensuring consistent access to parallelism group properties.
  • Backend Default Change: Updated the default ParallelismBackend from NATIVE_PYTORCH to AUTO for more flexible and automatic backend selection.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py
    • Imported get_sp_group from sglang.multimodal_gen.runtime.distributed.parallel_state.
    • Replaced getattr(sp_group, "ulysses_world_size", None) with get_sp_group().ulysses_world_size.
    • Replaced getattr(sp_group, "ring_world_size", None) with get_sp_group().ring_world_size.
    • Changed ParallelismBackend.NATIVE_PYTORCH to ParallelismBackend.AUTO in _build_parallelism_config.
Activity
  • The author has formatted the code according to the project's pre-commit guidelines.
  • Benchmarking results for sp-degree 4/tp-size 2 and sp-degree 8/tp-size 1 configurations were provided, demonstrating performance with the changes.
  • The author confirmed that the changes have no influence on accuracy.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@OrangeRedeng OrangeRedeng changed the title [Diffusion] Fix cache-dit with sp-degree only [Bugfix] [diffusion] Fix cache-dit with sp-degree only Mar 5, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes a bug related to using sp-degree with cache-dit by correctly retrieving parallelism information. The changes involve using get_sp_group() to access ulysses_world_size and ring_world_size, and updating the cache-dit backend to AUTO for better flexibility. The fix appears correct, but I've suggested a minor refactoring to avoid redundant function calls and improve code clarity.

Comment on lines +111 to +112
ulysses_size = get_sp_group().ulysses_world_size
ring_size = get_sp_group().ring_world_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and avoid calling get_sp_group() twice, you can store its result in a local variable and reuse it.

Suggested change
ulysses_size = get_sp_group().ulysses_world_size
ring_size = get_sp_group().ring_world_size
sp_coord = get_sp_group()
ulysses_size = sp_coord.ulysses_world_size
ring_size = sp_coord.ring_world_size

@yhyang201
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Mar 5, 2026
@ping1jing2 ping1jing2 linked an issue Mar 5, 2026 that may be closed by this pull request
5 tasks
@mickqian
Copy link
Copy Markdown
Collaborator

mickqian commented Mar 5, 2026

cc @DefTruth

@ping1jing2
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@ping1jing2 ping1jing2 self-assigned this Mar 5, 2026
if sp_group is not None:
ulysses_size = getattr(sp_group, "ulysses_world_size", None)
ring_size = getattr(sp_group, "ring_world_size", None)
ulysses_size = get_sp_group().ulysses_world_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use sp_group? It's no longer None.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure? I test it now and it still doesn't work on the latest version of slang from the main

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which PR fix this issue?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my case (and #19955) the sp_group has no attributes "ulysses_world_size" and "ring_world_size"

Copy link
Copy Markdown
Contributor

@DefTruth DefTruth Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my case (and #19955) the sp_group has no attributes "ulysses_world_size" and "ring_world_size"

Wired, @BBuf could you please also take a look? thanks~

Copy link
Copy Markdown
Contributor

@DefTruth DefTruth Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my case (and #19955) the sp_group has no attributes "ulysses_world_size" and "ring_world_size"

@OrangeRedeng Hi~ I think you are right, the sp/tp_group inside _build_parallelism_config are assumed to be torch.distributed.ProcessGroup, but ProcessGroup has no attributes "ulysses_world_size" and "ring_world_size". The bug #19955 you has encountered maybe cause by errored sp/tp_group assigning logics at:

sp_group = sp_group_candidate.device_group if has_sp else None
tp_group = tp_group_candidate.device_group if has_tp else None

Maybe we should change these code snippets to:

 sp_group = sp_group_candidate if has_sp else None 
 tp_group = tp_group_candidate if has_tp else None 

By the way, we should also fix the signature of _build_parallelism_config to avoid misleading developers with the wrong usage. From the old one:

def _build_parallelism_config(
sp_group: Optional[torch.distributed.ProcessGroup],
tp_group: Optional[torch.distributed.ProcessGroup],
):

to the correctly signature:

def _build_parallelism_config( 
     sp_group: Optional[GroupCoordinator], 
     tp_group: Optional[GroupCoordinator], 
 ): 

@OrangeRedeng Could you please take a try?
Also cc @mickqian

@DefTruth
Copy link
Copy Markdown
Contributor

DefTruth commented Mar 9, 2026

Also encounter the same error while using FLUX.1 with cache:

File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py", line 304, in enable_cache_on_transformer
    parallelism_config = _build_parallelism_config(sp_group, tp_group)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py", line 117, in _build_parallelism_config
    return ParallelismConfig(
           ^^^^^^^^^^^^^^^^^^
  File "<string>", line 26, in __init__
  File "/workspace/dev/vipshop/cache-dit/src/cache_dit/parallelism/config.py", line 119, in __post_init__
    raise ValueError(
ValueError: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
[03-09 08:21:06] Failed to generate output for prompt 1: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
Traceback (most recent call last):
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/utils/logging_utils.py", line 466, in log_generation_timer
    yield timer
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py", line 209, in generate
    raise Exception(f"{output_batch.error}")
Exception: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
[03-09 08:21:06] Generation failed for prompt 1/1: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
Traceback (most recent call last):
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py", line 209, in generate
    raise Exception(f"{output_batch.error}")
Exception: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.

cache-dit introduces a new "post_init" check after v1.2.3 that throws an error if all parallel sizes are not valid.

@DefTruth
Copy link
Copy Markdown
Contributor

DefTruth commented Mar 9, 2026

Also encounter the same error while using FLUX.1 with cache:

File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py", line 304, in enable_cache_on_transformer
    parallelism_config = _build_parallelism_config(sp_group, tp_group)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/cache/cache_dit_integration.py", line 117, in _build_parallelism_config
    return ParallelismConfig(
           ^^^^^^^^^^^^^^^^^^
  File "<string>", line 26, in __init__
  File "/workspace/dev/vipshop/cache-dit/src/cache_dit/parallelism/config.py", line 119, in __post_init__
    raise ValueError(
ValueError: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
[03-09 08:21:06] Failed to generate output for prompt 1: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
Traceback (most recent call last):
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/utils/logging_utils.py", line 466, in log_generation_timer
    yield timer
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py", line 209, in generate
    raise Exception(f"{output_batch.error}")
Exception: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
[03-09 08:21:06] Generation failed for prompt 1/1: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.
Traceback (most recent call last):
  File "/workspace/dev/vipshop/sglang/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py", line 209, in generate
    raise Exception(f"{output_batch.error}")
Exception: Error executing request 53f0ff81-08b2-4d81-8cdb-525b8eb653e0: No parallelism is enabled. Please set ulysses_size, ring_size, or tp_size to enable parallelism.

cache-dit introduces a new "post_init" check after v1.2.3 that throws an error if all parallel sizes are not valid.

How to reproduce:

SGLANG_CACHE_DIT_ENABLED=true \
  sglang generate --model-path $FLUX_DIR \
    --prompt "test prompt for checking bug" \
    --dit-cpu-offload false \
    --text-encoder-cpu-offload false \
    --sp-degree 4 --num-gpus 4

@DefTruth
Copy link
Copy Markdown
Contributor

DefTruth commented Mar 9, 2026

@OrangeRedeng @mickqian Confirm this fix works as expected. However, this remains a hacky workaround; we should instead consider addressing the root cause by passing the correct parallel groups to the helper function.

@OrangeRedeng
Copy link
Copy Markdown
Contributor Author

@DefTruth Thank you for your detailed response! I agree that the solution is temporary, and I will try to offer a new version of the fix ASAP

@OrangeRedeng
Copy link
Copy Markdown
Contributor Author

@DefTruth Hi, I took a closer look at the code, first of all, as it turned out, we already have a lot of built-in (and partially unused) functions, such as get_ring_parallel_world_size(), so now I use them, because it seems to me more correct than using getattr(.... Secondly, I tried returning GroupCoordinator instead of torch.distributed.ProcessGroup, but this leads to new errors:
image
and our code now is really difficult to understand, a lot of functions are either commented out or not used, especially inside /sglang/multimodal_gen/runtime/distributed/parallel_state.py. It might make sense to merge this workaround now, and then refactor parallelism related files to improve the code

@DefTruth
Copy link
Copy Markdown
Contributor

@DefTruth Hi, I took a closer look at the code, first of all, as it turned out, we already have a lot of built-in (and partially unused) functions, such as get_ring_parallel_world_size(), so now I use them, because it seems to me more correct than using getattr(.... Secondly, I tried returning GroupCoordinator instead of torch.distributed.ProcessGroup, but this leads to new errors: image and our code now is really difficult to understand, a lot of functions are either commented out or not used, especially inside /sglang/multimodal_gen/runtime/distributed/parallel_state.py. It might make sense to merge this workaround now, and then refactor parallelism related files to improve the code

cc @mickqian

@mickqian
Copy link
Copy Markdown
Collaborator

hey @OrangeRedeng , cleaning this up in: https://github.com/sgl-project/sglang/pull/20760/changes

@ping1jing2
Copy link
Copy Markdown
Collaborator

hey @OrangeRedeng , cleaning this up in: https://github.com/sgl-project/sglang/pull/20760/changes

Hi @mickqian , we've rebased and retriggered CI. now all CIs passed, shall we merge this PR?

@yhyang201
Copy link
Copy Markdown
Collaborator

@mickqian All CI (Nvidia + AMD) passed and PR is approved, ready for merge

— SGLDHelper bot

@BBuf BBuf merged commit c64681f into sgl-project:main Mar 18, 2026
69 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…9965)

Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
…9965)

Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
…9965)

Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
…9965)

Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…9965)

Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

diffusion SGLang Diffusion run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] [Diffusion] Failed Wan2.2 with cache-dit with sp-degree only

6 participants