
Support regional compile#17702

Open
WhoisZihan wants to merge 5 commits into sgl-project:main from WhoisZihan:feature/regional-compile

Conversation

@WhoisZihan

Motivation

Currently, the compile integration for diffusion models only supports compiling the whole model. The traced graph then includes every repeated layer (e.g., each transformer block), which can lead to long compile times.

To address this, other frameworks such as vLLM and transformers use regional compilation to compile only the repeated blocks (you can of course manually add other non-repeated blocks as well), so that each block type is compiled once and its code cache is reused across all identical blocks.
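To make the idea concrete, here is a minimal plain-Python sketch (no torch; all names are hypothetical stand-ins) of why regional compilation pays off: compilation is keyed by the block's class, so 32 identical blocks trigger one compilation instead of 32 traces.

```python
# Illustrative sketch: compile each *class* of repeated block once and reuse
# the compiled artifact for every instance, instead of tracing the whole model.

def fake_compile(fn, _cache={}, _counter=[0]):
    """Stand-in for torch.compile: counts how often real compilation happens."""
    key = getattr(fn, "__qualname__", repr(fn))
    if key not in _cache:
        _counter[0] += 1          # a real framework would trace/codegen here
        _cache[key] = fn
    return _cache[key], _counter[0]

class Block:
    def forward(self, x):
        return x + 1

blocks = [Block() for _ in range(32)]  # 32 repeated transformer blocks

# Regional compile: one compilation shared across all 32 instances.
compiled = [fake_compile(type(b).forward) for b in blocks]
num_compilations = compiled[-1][1]
print(num_compilations)  # the Block class is compiled once, not 32 times
```

With full-model compilation the tracer would instead walk through all 32 blocks in a single graph, which is where the long first-step time in the benchmark below comes from.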

Modifications

A new server argument is added to control whether regional compilation is used.
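As a hedged sketch of how such a flag might be wired up (the actual SGLang server-args plumbing differs; only the `--regional-compile` flag name comes from this PR, the rest is illustrative):

```python
import argparse

# Hypothetical flag registration; real SGLang uses its own ServerArgs machinery.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-torch-compile", action="store_true")
parser.add_argument(
    "--regional-compile",
    action="store_true",
    help="Compile only repeated blocks (meaningful with --enable-torch-compile).",
)

args = parser.parse_args(["--enable-torch-compile", "--regional-compile"])
print(args.regional_compile)  # True
```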

Accuracy Tests

Benchmarking and Profiling

I only have an A100 at hand, so I used the smaller model Wan2.2-TI2V-5B-Diffusers just to illustrate how this works.

```shell
SGLANG_DIFFUSION_ATTENTION_CONFIG=./mask_strategy_wan.json \
sglang generate \
  --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers/ \
  --pin-cpu-memory \
  --num-gpus 8 \
  --ulysses-degree 8 \
  --attention-backend sage_attn \
  --enable-torch-compile \
  --regional-compile \
  --prompt "A cat walks on the grass, realistic" \
  --num-frames 81 \
  --height 480 \
  --width 832 \
  --num-inference-steps 27 \
  --guidance-scale 3.5 \
  --guidance-scale-2 4.0 \
  --perf-dump-path ./dump/wan_step_profile_cp8_main.json
```
| Option | Stage | Time (s) |
| --- | --- | --- |
| Full | denoising | 27.3190 |
| Full | decoding | 33.2832 |
| Full | TOTAL | 68.32 |
| Regional | denoising | 43.6123 |
| Regional | decoding | 9.6507 |
| Regional | TOTAL | 61.46 |

The step-wise denoising durations are:

| Denoise Step | Full (ms) | Regional (ms) |
| --- | --- | --- |
| 0 | 13365.79 | 4963.19 |
| 1 | 466.87 | 1488.04 |
| 2 | 1568.19 | 1474.05 |
| 3 | 1559.44 | 1472.11 |
| 4 | 384.02 | 1496.27 |
| ... | ... | ... |
| 26 | 924.59 | 1472.33 |

Observations

There are a few interesting observations from these numbers:

  1. The first-step time (dominated by compilation) drops by about 8.41s (13.37s -> 4.96s).
  2. The total end-to-end time only decreases by about 6.86s (68.32s -> 61.46s), which suggests regional compile introduces some overhead for the non-regional components. Part of that overhead likely also comes from graph breaks; if we mitigate those breaks, the performance difference should be smaller, as per the benchmark by diffusers.
  3. The step-wise times do not match between the full and regional runs, which could be due to synchronization in stage perf, e.g., the non-regional part might contain some unnecessary sync operations, or there might be bugs in stage perf. I haven't had time to dive deep into the perf part yet.

For now I only add the transformer blocks to the repeated blocks, and have not tested adding other non-repeated layers, which should help improve performance as well.

Also, there are many graph breaks inside the model right now, which could lead to unexpected timing changes too.

Plan for mitigating graph breaks and a more general compilation mechanism

From the dynamo logs, there are many graph breaks during the compile procedure. Because of these massive graph breaks, performance has definitely not reached the peak of the compiled artifacts, regardless of full or regional compile.

I have already figured out a few graph break points from the dynamo logs, and plan to mitigate such breaks in future PRs.

Also, I think the diffusion model layers are not very compile-friendly at the moment, and it would be great to have a more structured and flexible compilation mechanism, which could introduce more optimization opportunities.

There is already ongoing work in #11830 to integrate such a compilation mechanism, which would be really helpful for fully utilizing the torch.compile benefits. Is there any plan to proceed with that work?

I am glad to help and contribute to the compile mechanism if possible.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions github-actions bot added the diffusion SGLang Diffusion label Jan 25, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @WhoisZihan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for regional compilation in diffusion models, specifically targeting repeated blocks like transformer layers. The primary goal is to significantly reduce the initial compilation overhead and overall inference time by compiling only the frequently reused parts of the model, rather than the entire graph. This is exposed via a new command-line argument, allowing users to easily enable this optimization and achieve faster model execution, as demonstrated by the provided benchmarks.

Highlights

  • Regional Compile Feature: Introduced a new --regional-compile server argument to enable selective compilation of repeated blocks within diffusion models, aiming to reduce overall compilation time.
  • Performance Improvement: Benchmarking shows a significant reduction in first-step compile time (from 13.37s to 4.96s) and an overall end-to-end time decrease of 6.86s (from 68.32s to 61.46s) when using regional compilation.
  • Modular Compilation Logic: Implemented a regionally_compile method that identifies and compiles specific submodules (e.g., WanTransformerBlock) marked as _repeated_blocks, allowing for more granular optimization.
  • Model Structure Adaptation: The WanTransformer3DModel now includes a _repeated_blocks attribute to specify which components should benefit from regional compilation.
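To illustrate the mechanism the highlights describe, here is a hedged plain-Python sketch of a `regionally_compile`-style helper. The names `WanTransformerBlock` and `_repeated_blocks` come from the PR description; the module tree, `named_submodules`, and `stub_compile` are hypothetical stand-ins for `torch.nn.Module` traversal and `torch.compile`, not the actual SGLang implementation.

```python
# Sketch: walk the model's submodules and "compile" only those whose class
# name is listed in the model's _repeated_blocks attribute.

def stub_compile(forward_fn):
    """Stand-in for torch.compile(module): wrap the forward callable."""
    def compiled(*args, **kwargs):
        return forward_fn(*args, **kwargs)
    compiled.is_compiled = True
    return compiled

class WanTransformerBlock:
    def forward(self, x):
        return x * 2

class WanTransformer3DModel:
    _repeated_blocks = ("WanTransformerBlock",)

    def __init__(self, num_layers=4):
        self.blocks = [WanTransformerBlock() for _ in range(num_layers)]

    def named_submodules(self):  # stand-in for nn.Module.named_modules()
        for i, block in enumerate(self.blocks):
            yield f"blocks.{i}", block

def regionally_compile(model):
    """Replace forward on every repeated block with a compiled version."""
    compiled_count = 0
    for _name, module in model.named_submodules():
        if type(module).__name__ in model._repeated_blocks:
            module.forward = stub_compile(module.forward)
            compiled_count += 1
    return compiled_count

model = WanTransformer3DModel()
n = regionally_compile(model)
print(n)  # all 4 repeated blocks received the (stubbed) compiled forward
```

In the real feature, `stub_compile` would be `torch.compile`, and the per-class code cache is what makes compiling the second and later instances cheap.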





@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces regional compilation to speed up the model compilation time, which is a valuable performance improvement. The changes involve adding a new server argument --regional-compile and logic to compile only specific repeated blocks of the model.

My review has identified two critical issues:

  1. An unused and buggy class WanPreTransformerLayers has been added in wanvideo.py. It should be removed to avoid confusion and potential errors.
  2. The new regional compilation feature is not activated because the call sites for _maybe_enable_torch_compile in denoising.py have not been updated to pass the new regional flag.

Addressing these points will ensure the new feature works as intended and the code remains clean.


python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py (678-697)

critical

The new class WanPreTransformerLayers appears to be unused in the codebase. Additionally, its __init__ method calls super().__init__(config=config, hf_config=hf_config), but torch.nn.Module.__init__ does not accept these arguments. This will cause a TypeError if the class is ever instantiated.

The layers defined within it also seem to be duplicated from WanTransformer3DModel. It looks like this might be leftover code from a refactoring. To avoid confusion and potential bugs, it would be best to remove this class entirely.

python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py (119)

critical

The new regional parameter is a great addition for enabling regional compilation. However, the call sites for this function don't seem to be updated to pass this new parameter. The value of regional will always be its default False, so the regional compilation logic will never be triggered.

You should update the calls in DenoisingStage.__init__ (line 97) and DenoisingStage._prepare_denoising_loop (line 523) to pass regional=self.server_args.regional_compile.

For example, in __init__:

```python
for transformer in filter(None, [self.transformer, self.transformer_2]):
    self._maybe_enable_torch_compile(transformer, regional=self.server_args.regional_compile)
```

Without this change, the new feature is not active.

