
[Diffusion] Apply qknorm to flux2 and apply lightx2v rms_norm_one_pass kernel (without residual) #17305

Merged
BBuf merged 7 commits into main from apply_qknorm_to_flux2
Jan 19, 2026

Conversation

BBuf (Collaborator) commented Jan 18, 2026

Motivation

FLUX2

sglang generate --model-path black-forest-labs/FLUX.2-dev --prompt "A Logo With Bold Large Text: SGL Diffusion" --width=1024 --height=1024 --dit-layerwise-offload false --enable-torch-compile --warmup --dit-cpu-offload false --text-encoder-cpu-offload true --vae-cpu-offload false

main:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00, 2.24it/s]
[01-18 14:19:18] [DenoisingStage] average time per step: 0.4461 seconds
[01-18 14:19:19] [DenoisingStage] finished in 22.7520 seconds

pr:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:21<00:00, 2.34it/s]
[01-18 14:52:54] [DenoisingStage] average time per step: 0.4280 seconds
[01-18 14:52:55] [DenoisingStage] finished in 21.8318 seconds

This PR gives roughly a 4.2% end-to-end performance improvement for FLUX2.
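The end-to-end figure follows from the reported wall-clock times; as a quick sanity check:

```python
# Wall-clock times reported by DenoisingStage for the 50-step FLUX2 run.
main_s = 22.7520  # main branch
pr_s = 21.8318    # this PR

# Relative end-to-end improvement: (old - new) / new.
speedup_pct = (main_s - pr_s) / pr_s * 100
print(f"{speedup_pct:.1f}%")  # → 4.2%
```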


Qwen-Image-Edit-2511

sglang generate --model-path Qwen/Qwen-Image-Edit-2511 --prompt "Change the person to a standing position, bending over to hold the dog's front paws." --image-path "/home/lmsys/bbuf/LightX2V/examples/qwen_image/1.png" --warmup --enable-torch-compile
main:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:23<00:00, 1.73it/s]
[01-18 15:11:37] [DenoisingStage] average time per step: 0.5769 seconds
[01-18 15:11:37] [DenoisingStage] finished in 23.0815 seconds

pr:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:23<00:00, 1.73it/s]
[01-18 15:13:52] [DenoisingStage] average time per step: 0.5768 seconds
[01-18 15:13:52] [DenoisingStage] finished in 23.0749 seconds

Since Qwen-Image has almost no standalone RMSNorm, performance remains unchanged.

Micro benchmark

Provided by @triple-mu

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

import flashinfer

@triton.jit
def _rms_norm_tiled_onepass(
    y_ptr,
    x_ptr,
    w_ptr,
    SEQ: tl.constexpr,
    DIM: tl.constexpr,
    EPS: tl.constexpr,
    BLOCK_SIZE_SEQ: tl.constexpr,
    BLOCK_SIZE_DIM: tl.constexpr,
):
    seq_blk_id = tl.program_id(0)
    seq_id = seq_blk_id * BLOCK_SIZE_SEQ

    seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]
    s_mask = seq_offset < SEQ
    d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :]
    d_mask = d_offset < DIM
    y_blk = y_ptr + seq_offset * DIM + d_offset
    x_blk = x_ptr + seq_offset * DIM + d_offset
    mask = s_mask & d_mask

    x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)
    mean_square = tl.sum(x * x, axis=1, keep_dims=True) / DIM
    rstd = tl.math.rsqrt(mean_square + EPS)
    w = tl.load(w_ptr + d_offset, mask=d_mask)
    tl.store(y_blk, x * rstd * w, mask=mask)


def rms_norm_kernel(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    shape = x.shape
    x = x.contiguous()
    y = torch.empty_like(x)
    x_view = x.reshape(-1, shape[-1])
    y_view = y.reshape(-1, shape[-1])
    S, D = x_view.shape

    BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))
    grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)

    with torch.cuda.device(x.device):
        torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](
            y_view,
            x_view,
            w,
            S,
            D,
            eps,
            BLOCK_SIZE_DIM=triton.next_power_of_2(D),
            BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
        )
    return y


@torch.inference_mode()
def main():

    for head_size in [32, 64, 128]:
        print(f"Testing head_size: {head_size}")

        dtype = torch.bfloat16
        device = torch.device('cuda:0')

        seq_lens = [1 << i for i in range(10, 19)]
        inputs = [
            torch.randn(seq_len, 40, head_size, dtype=dtype, device=device) for seq_len in seq_lens
        ]

        weight = torch.randn(head_size, dtype=dtype, device=device)

        quantiles = [0.5, 0.2, 0.8]

        # Warm up each implementation a few times before benchmarking.
        for i in range(3):
            for inp in inputs:
                flashinfer.norm.rmsnorm(inp, weight, eps=1e-6)
                rms_norm_kernel(inp, weight, eps=1e-6)
                F.rms_norm(inp, (head_size,), weight, eps=1e-6)


        fns = {}
        for inp in inputs:
            # Bind the tensor as a default argument; a plain closure over the
            # loop variable would make every lambda benchmark the last input.
            fn1 = lambda t=inp: flashinfer.norm.rmsnorm(t, weight, eps=1e-6)
            fn2 = lambda t=inp: rms_norm_kernel(t, weight, eps=1e-6)
            fn3 = lambda t=inp: F.rms_norm(t, (head_size,), weight, eps=1e-6)
            fns[tuple(inp.shape)] = (fn1, fn2, fn3)
        
        for shape, all_fns in fns.items():
            for i, fn in enumerate(all_fns):
                ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
                print(f'{shape}: fn[{i}] {ms:.5f} ms')
            print('*' * 100)
        
        print('='*50)


if __name__ == '__main__':
    main()

Result:

Testing head_size: 32
(1024, 40, 32): fn[0] 2.76439 ms
(1024, 40, 32): fn[1] 0.45949 ms
(1024, 40, 32): fn[2] 6.36691 ms
****************************************************************************************************
(2048, 40, 32): fn[0] 2.76452 ms
(2048, 40, 32): fn[1] 0.45889 ms
(2048, 40, 32): fn[2] 6.37582 ms
****************************************************************************************************
(4096, 40, 32): fn[0] 2.76442 ms
(4096, 40, 32): fn[1] 0.45948 ms
(4096, 40, 32): fn[2] 6.36517 ms
****************************************************************************************************
(8192, 40, 32): fn[0] 2.76528 ms
(8192, 40, 32): fn[1] 0.45889 ms
(8192, 40, 32): fn[2] 6.37479 ms
****************************************************************************************************
(16384, 40, 32): fn[0] 2.76506 ms
(16384, 40, 32): fn[1] 0.45951 ms
(16384, 40, 32): fn[2] 6.36547 ms
****************************************************************************************************
(32768, 40, 32): fn[0] 2.76439 ms
(32768, 40, 32): fn[1] 0.45890 ms
(32768, 40, 32): fn[2] 6.37488 ms
****************************************************************************************************
(65536, 40, 32): fn[0] 2.76483 ms
(65536, 40, 32): fn[1] 0.45949 ms
(65536, 40, 32): fn[2] 6.36510 ms
****************************************************************************************************
(131072, 40, 32): fn[0] 2.76436 ms
(131072, 40, 32): fn[1] 0.45889 ms
(131072, 40, 32): fn[2] 6.37313 ms
****************************************************************************************************
(262144, 40, 32): fn[0] 2.76399 ms
(262144, 40, 32): fn[1] 0.45950 ms
(262144, 40, 32): fn[2] 6.36560 ms
****************************************************************************************************
==================================================
Testing head_size: 64
(1024, 40, 64): fn[0] 2.75171 ms
(1024, 40, 64): fn[1] 0.87956 ms
(1024, 40, 64): fn[2] 6.37175 ms
****************************************************************************************************
(2048, 40, 64): fn[0] 2.75112 ms
(2048, 40, 64): fn[1] 0.88112 ms
(2048, 40, 64): fn[2] 6.37811 ms
****************************************************************************************************
(4096, 40, 64): fn[0] 2.75024 ms
(4096, 40, 64): fn[1] 0.87958 ms
(4096, 40, 64): fn[2] 6.37240 ms
****************************************************************************************************
(8192, 40, 64): fn[0] 2.75057 ms
(8192, 40, 64): fn[1] 0.88118 ms
(8192, 40, 64): fn[2] 6.37649 ms
****************************************************************************************************
(16384, 40, 64): fn[0] 2.75042 ms
(16384, 40, 64): fn[1] 0.87955 ms
(16384, 40, 64): fn[2] 6.37218 ms
****************************************************************************************************
(32768, 40, 64): fn[0] 2.75036 ms
(32768, 40, 64): fn[1] 0.88116 ms
(32768, 40, 64): fn[2] 6.37589 ms
****************************************************************************************************
(65536, 40, 64): fn[0] 2.75120 ms
(65536, 40, 64): fn[1] 0.87947 ms
(65536, 40, 64): fn[2] 6.37186 ms
****************************************************************************************************
(131072, 40, 64): fn[0] 2.75116 ms
(131072, 40, 64): fn[1] 0.88108 ms
(131072, 40, 64): fn[2] 6.37512 ms
****************************************************************************************************
(262144, 40, 64): fn[0] 2.75039 ms
(262144, 40, 64): fn[1] 0.87956 ms
(262144, 40, 64): fn[2] 6.37167 ms
****************************************************************************************************
==================================================
Testing head_size: 128
(1024, 40, 128): fn[0] 2.77051 ms
(1024, 40, 128): fn[1] 1.76063 ms
(1024, 40, 128): fn[2] 6.38813 ms
****************************************************************************************************
(2048, 40, 128): fn[0] 2.77091 ms
(2048, 40, 128): fn[1] 1.76215 ms
(2048, 40, 128): fn[2] 6.38253 ms
****************************************************************************************************
(4096, 40, 128): fn[0] 2.77168 ms
(4096, 40, 128): fn[1] 1.76031 ms
(4096, 40, 128): fn[2] 6.38663 ms
****************************************************************************************************
(8192, 40, 128): fn[0] 2.76875 ms
(8192, 40, 128): fn[1] 1.76225 ms
(8192, 40, 128): fn[2] 6.38367 ms
****************************************************************************************************
(16384, 40, 128): fn[0] 2.77135 ms
(16384, 40, 128): fn[1] 1.76003 ms
(16384, 40, 128): fn[2] 6.38361 ms
****************************************************************************************************
(32768, 40, 128): fn[0] 2.76964 ms
(32768, 40, 128): fn[1] 1.76233 ms
(32768, 40, 128): fn[2] 6.38564 ms
****************************************************************************************************
(65536, 40, 128): fn[0] 2.77183 ms
(65536, 40, 128): fn[1] 1.76024 ms
(65536, 40, 128): fn[2] 6.38377 ms
****************************************************************************************************
(131072, 40, 128): fn[0] 2.77039 ms
(131072, 40, 128): fn[1] 1.76203 ms
(131072, 40, 128): fn[2] 6.38260 ms
****************************************************************************************************
(262144, 40, 128): fn[0] 2.77206 ms
(262144, 40, 128): fn[1] 1.76016 ms
(262144, 40, 128): fn[2] 6.39005 ms
****************************************************************************************************
==================================================
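For reference, the per-row math the one-pass kernel computes (no mean subtraction, a single reduction over the last dimension, accumulation in float32) can be sketched in plain NumPy; `rms_norm_ref` is an illustrative name, not part of this PR:

```python
import numpy as np

def rms_norm_ref(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Reference RMSNorm over the last dimension, mirroring the Triton kernel:
    y = x * rsqrt(mean(x^2) + eps) * w, computed in float32."""
    x32 = x.astype(np.float32)
    mean_square = np.mean(x32 * x32, axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(mean_square + eps)
    return x32 * rstd * w.astype(np.float32)

x = np.random.randn(4, 40, 128).astype(np.float32)
w = np.random.randn(128).astype(np.float32)
y = rms_norm_ref(x, w)
assert y.shape == x.shape
```

This is the "one-pass" structure: each row needs only one load of `x`, one reduction for the mean square, and one store, which is why the Triton kernel above can tile several rows per program.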


@github-actions github-actions bot added the diffusion SGLang Diffusion label Jan 18, 2026
gemini-code-assist bot commented:

Summary of Changes

Hello @BBuf, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant performance optimizations by integrating a custom Triton-based one-pass RMSNorm kernel and a fused query-key normalization (qknorm) mechanism. These changes are primarily aimed at enhancing the efficiency of the FLUX2 model, resulting in a notable speedup during denoising stages. The new RMSNorm kernel is conditionally applied based on tensor dimensions, ensuring targeted performance gains without affecting other models where these specific optimizations are not applicable or beneficial.

Highlights

  • Performance Optimization for FLUX2: Achieved an approximate 4.2% end-to-end performance improvement for the FLUX2 model during denoising stages.
  • New Triton One-Pass RMSNorm Kernel: Implemented a custom Triton-based one-pass RMSNorm kernel, adapted from LightX2V, which significantly outperforms existing flashinfer.norm.rmsnorm and torch.nn.functional.rms_norm for various sequence lengths and head sizes.
  • Conditional RMSNorm Application: Integrated the new Triton RMSNorm kernel into the RMSNorm layer, applying it conditionally when the input tensor's last dimension is 128 or less for targeted efficiency gains.
  • Fused Query-Key Normalization (qknorm): Applied a fused, in-place query-key normalization (qknorm) strategy within the Flux2 model's attention mechanism, leveraging a custom kernel for enhanced efficiency under specific conditions.
  • Qwen-Image-Edit-2511 Performance: No significant performance change observed for the Qwen-Image-Edit-2511 model, as it does not heavily utilize the specific RMSNorm optimizations introduced.
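The conditional application described in the highlights (use the Triton kernel only when the last dimension is at most 128) could, schematically, look like the following dispatch; the names `rms_norm_dispatch` and `MAX_TRITON_RMSNORM_DIM` are hypothetical, and the actual integration lives in the PR's `layernorm.py`:

```python
# Hypothetical named constant for the dim-128 threshold mentioned in the review.
MAX_TRITON_RMSNORM_DIM = 128

def rms_norm_dispatch(last_dim: int, cuda_available: bool = True) -> str:
    """Pick the one-pass Triton kernel only for small last dims on GPU;
    otherwise fall back to the framework RMSNorm implementation."""
    if cuda_available and last_dim <= MAX_TRITON_RMSNORM_DIM:
        return "triton_one_pass"
    return "framework_rms_norm"

print(rms_norm_dispatch(128))  # triton_one_pass
print(rms_norm_dispatch(256))  # framework_rms_norm
```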



@BBuf BBuf changed the title []Apply qknorm to flux2 and apply lightx2v rms_norm_one_pass kernel(without residual) [Diffusion] Apply qknorm to flux2 and apply lightx2v rms_norm_one_pass kernel(without residual) Jan 18, 2026
gemini-code-assist bot left a comment:

Code Review

This pull request introduces performance optimizations by applying a fused QK-norm kernel for Flux2 and a new one-pass RMSNorm Triton kernel from LightX2V. The changes show a good performance improvement.

My review focuses on improving code quality and correctness:

  • In triton_ops.py, I've identified a potential bug in the new Triton kernel where an incorrect mask is used for loading weights, which could lead to incorrect results and reduced performance. I've suggested a fix.
  • In flux_2.py, I've pointed out code duplication for applying the fused QK-norm and suggested refactoring it into a helper function for better maintainability.
  • In layernorm.py, I've recommended replacing a magic number with a named constant to improve code clarity.

Overall, the changes are beneficial, and with these adjustments, the code will be more robust and maintainable.

BBuf (Collaborator, Author) commented Jan 18, 2026

/tag-and-rerun-ci

BBuf merged commit cc410a1 into main on Jan 19, 2026
154 of 160 checks passed
BBuf deleted the apply_qknorm_to_flux2 branch on January 19, 2026
DotSlash-A pushed a commit to DotSlash-A/sglang that referenced this pull request Jan 19, 2026