
[diffusion]: optimize zimage with rmsnorm + add + (rmsnorm)#16404

Open
attack204 wants to merge 1 commit into sgl-project:main from attack204:feature/gaoji_support_zimage_rmsnorm_add_rmsnorm_final

Conversation

@attack204
Contributor

@attack204 attack204 commented Jan 4, 2026

E2E IMPROVEMENT

~10 s (13.4%) faster end to end

```shell
python python/sglang/multimodal_gen/benchmarks/compare_perf.py baseline.json new.json
```
### Performance Comparison Report

1. High-level Summary

| Metric | Baseline | New | Diff | Status |
| --- | --- | --- | --- | --- |
| E2E Latency | 71805.84 ms | 62156.90 ms | -9648.94 ms (-13.4%) | |
| Throughput | 0.01 req/s | 0.02 req/s | - | - |

2. Stage Breakdown

| Stage Name | Baseline (ms) | New (ms) | Diff (ms) | Diff (%) | Status |
| --- | --- | --- | --- | --- | --- |
| InputValidationStage | 0.06 | 0.06 | -0.01 | -9.4% | ⚪️ |
| TextEncodingStage | 929.59 | 686.98 | -242.61 | -26.1% | 🟢 |
| ConditioningStage | 0.03 | 0.03 | -0.00 | -7.0% | ⚪️ |
| TimestepPreparationStage | 2.25 | 2.12 | -0.13 | -5.7% | ⚪️ |
| LatentPreparationStage | 9.30 | 4.89 | -4.41 | -47.4% | ⚪️ |
| DenoisingStage | 69765.61 | 60430.47 | -9335.14 | -13.4% | 🟢 |
| DecodingStage | 1061.44 | 1000.04 | -61.40 | -5.8% | ⚪️ |
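The diff columns follow directly from the two benchmark runs; a minimal sketch of the comparison arithmetic (illustrative only, not the actual compare_perf.py):

```python
def stage_diff(baseline_ms: float, new_ms: float) -> tuple[float, float]:
    """Absolute and relative change of one stage, matching the report's columns."""
    diff = new_ms - baseline_ms
    pct = 100.0 * diff / baseline_ms
    return diff, pct

# E2E row from the report above:
diff, pct = stage_diff(71805.84, 62156.90)
# diff ≈ -9648.94 ms, pct ≈ -13.4 %
```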

DETAILED PERF

TODO

@github-actions github-actions bot added the diffusion SGLang Diffusion label Jan 4, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @attack204, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the zimage diffusion model by implementing fused RMSNorm and addition operations. These changes are designed to streamline computational graphs and leverage GPU parallelism more effectively, resulting in a substantial reduction in overall execution time. The integration of these new kernels is expected to enhance the model's efficiency without altering its core functionality.

Highlights

  • Performance Optimization: Introduced fused RMSNorm and addition operations to significantly improve the end-to-end latency of the zimage diffusion model.
  • New Fused Kernels: Added two new CUDA kernels: fused_rmsnorm_add_rmsnorm (for RMSNorm(residual) + x -> output1, then RMSNorm(output1) -> output2) and fused_rmsnorm_add (for RMSNorm(residual) + x -> output).
  • Integration: Integrated these fused kernels into the ZImageTransformerBlock within the layernorm.py module, with a fallback to the original implementation if fused kernels cannot be initialized.
  • Benchmark Results: Achieved a 13.4% (9.6 seconds) reduction in E2E Latency, with notable improvements in TextEncodingStage (-26.1%) and DenoisingStage (-13.4%).
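For reference, the semantics of the two fused ops can be sketched in plain PyTorch (an unfused reference matching the descriptions above, not the CUDA kernels in this PR; `eps` and the weight handling are assumptions):

```python
import torch

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_rmsnorm_add(x, residual, weight, eps=1e-6):
    # RMSNorm(residual) + x -> output
    return rmsnorm(residual, weight, eps) + x

def fused_rmsnorm_add_rmsnorm(x, residual, w1, w2, eps=1e-6):
    # RMSNorm(residual) + x -> output1, then RMSNorm(output1) -> output2
    output1 = rmsnorm(residual, w1, eps) + x
    output2 = rmsnorm(output1, w2, eps)
    return output1, output2
```

The fused CUDA kernels compute the same thing in one pass over memory, which is where the DenoisingStage savings come from.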


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces optimized fused kernels for RMSNorm operations within the ZImage model, which yields a significant performance improvement. The changes are well-structured, including new CUDA kernels, Python wrappers for JIT compilation, and seamless integration into the model with a fallback mechanism for compatibility. My review has identified one critical issue that would cause a runtime error, along with a few medium-severity suggestions to enhance code maintainability, robustness, and performance. Addressing these points will further solidify the quality of this optimization.

```python
        output1: intermediate result (x + RMSNorm(residual)) [batch_size, seq_len, hidden_size]
        output2: final result (RMSNorm(output1)) [batch_size, seq_len, hidden_size]
    """
    can_use_cuda = _is_cuda and (x.shape[-1] % 4 == 0)
```
Contributor


critical

The variable _is_cuda is not defined, which will cause a NameError at runtime. It seems you intended to check if the tensor is on a CUDA device. You should use x.is_cuda instead. This issue is also present on line 583 for the FusedRMSNormAdd class.

Suggested change

```diff
- can_use_cuda = _is_cuda and (x.shape[-1] % 4 == 0)
+ can_use_cuda = x.is_cuda and (x.shape[-1] % 4 == 0)
```

Comment on lines +21 to +24

```python
# TODO: workaround, do not import cutlass from flashinfer
cutlass_include = os.path.join(
    os.path.dirname(flashinfer.__file__), "data", "cutlass", "include"
)
```
Contributor


medium

The cutlass_include path is constructed based on the internal directory structure of the flashinfer package. This creates a brittle dependency that might break if flashinfer changes its packaging in a future version. As noted by the TODO, this is a workaround. For long-term stability, it would be more robust to find a way to get this path through flashinfer's public API if available, or consider vendoring the required headers. This comment also applies to python/sglang/jit_kernel/diffusion/fused_rmsnorm_add_rmsnorm.py.
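One incremental hardening (still a workaround, and assuming the headers remain packaged under `data/cutlass/include`) is to fail loudly when the expected layout changes; `find_cutlass_include` is a hypothetical helper, not part of flashinfer's API:

```python
import os

def find_cutlass_include(flashinfer_dir: str) -> str:
    # Hypothetical helper: locate the vendored cutlass headers and raise a
    # clear error if flashinfer's packaging layout has changed, instead of
    # passing a nonexistent include path to the JIT compiler.
    path = os.path.join(flashinfer_dir, "data", "cutlass", "include")
    if not os.path.isdir(path):
        raise RuntimeError(
            f"cutlass headers not found at {path}; "
            "flashinfer packaging may have changed"
        )
    return path
```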

Comment on lines +541 to +547

```python
def _rmsnorm(self, x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor:
    """Fallback RMSNorm implementation"""
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normalized = x * torch.rsqrt(variance + self.eps)
    if weight is not None:
        x_normalized = x_normalized * weight
    return x_normalized.to(x.dtype)
```
Contributor


medium

The _rmsnorm method is duplicated in both FusedRMSNormAddRMSNorm (lines 541-547) and FusedRMSNormAdd (lines 602-608). To improve maintainability and reduce code duplication, consider refactoring this into a static method or a standalone helper function.

@DarkSharpness
Collaborator

DarkSharpness commented Jan 4, 2026

Do we have any test/benchmark on this kernel in python/sglang/jit_kernel? I think this can help more contributors understand the implementation and usage of this kernel, as well as test the correctness.

```cpp
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/optional.h>

#include "cutlass/numeric_types.h"
```
Collaborator


Can we try to eliminate this dependency? It seems only numeric types are used.

@RubiaCx RubiaCx added the run-ci label Jan 5, 2026
@yingluosanqian
Collaborator

After this MR is merged, we could pass the dtype as a C++ template parameter, eliminating the need for runtime dtype dispatch:

```cpp
if (dtype.code == kDLFloat && dtype.bits == 32) {
  dispatch_ipt(DTypeTag<float4, float>{});
} else if (dtype.code == kDLFloat && dtype.bits == 16) {
  dispatch_ipt(DTypeTag<half4, half>{});
} else if (dtype.code == kDLBfloat && dtype.bits == 16) {
  dispatch_ipt(DTypeTag<bf16_4, cutlass::bfloat16_t>{});
}
```

@DarkSharpness
Collaborator

DarkSharpness commented Jan 6, 2026

What's the relationship between this PR and #14717?

It seems there is a lot of redundant code between these two PRs. We may need to maximize code reuse and avoid duplicating similar code.

@yingluosanqian
Collaborator

> What's the relationship between this PR and #14717?
>
> It seems there is a lot of redundant code between these two PRs. We may need to maximize code reuse and avoid duplicating similar code.

We have a kernel-fusion plan for the diffusion kernels, many of which involve norm operations.

I summarized how we currently end up generating a large number of kernel source variants:

  1. Some kernels must handle both 3D and 4D input shapes.
  2. When N is smaller than a certain threshold, placing the input in shared memory yields better performance.
  3. Different kernels have different fusion patterns. For example, some kernels fuse norm + scale + shift, while others fuse norm + tanh + scale + shift, or norm + add + norm. This is especially problematic, and we may need a better abstraction to maximize code reuse.
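To illustrate point 3, one possible abstraction (a Python-side sketch of the idea, not the planned CUDA design) expresses each fusion pattern as a shared norm core plus a composable epilogue:

```python
from typing import Callable
import torch

def rmsnorm_core(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Shared normalization core reused by every fusion pattern.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(x.dtype)

def scale_shift(scale, shift) -> Callable[[torch.Tensor], torch.Tensor]:
    # Epilogue for the norm + scale + shift pattern.
    return lambda y: y * scale + shift

def tanh_scale_shift(scale, shift) -> Callable[[torch.Tensor], torch.Tensor]:
    # Epilogue for the norm + tanh + scale + shift pattern.
    return lambda y: torch.tanh(y) * scale + shift

def fused_norm(x: torch.Tensor, epilogue: Callable) -> torch.Tensor:
    # Each fused kernel variant = the same core with a different epilogue.
    return epilogue(rmsnorm_core(x))
```

In CUDA the epilogue would be a template parameter (or functor) so each pattern compiles to one specialized kernel while sharing the normalization code.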

@attack204
Contributor Author

> What's the relationship between this PR and #14717?
>
> It seems there is a lot of redundant code between these two PRs. We may need to maximize code reuse and avoid duplicating similar code.

This PR is based on #14717 and uses some of its utility functions, so once #14717 is merged, this PR will be rebased.

