Skip to content

[Kernel] Fuse temperature + softmax in sampling for decode speedup#20501

Merged
BBuf merged 18 commits intosgl-project:mainfrom
Godmook:fused_sampling
Apr 2, 2026
Merged

[Kernel] Fuse temperature + softmax in sampling for decode speedup#20501
BBuf merged 18 commits intosgl-project:mainfrom
Godmook:fused_sampling

Conversation

@Godmook
Copy link
Copy Markdown
Contributor

@Godmook Godmook commented Mar 13, 2026

Motivation

The standard sampling path in sampler.py executes temperature scaling and softmax as two separate CUDA kernels:

logits.div_(sampling_info.temperatures)       # Kernel 1
logits[:] = torch.softmax(logits, dim=-1)     # Kernel 2

This is called on every decode step and incurs:

  • 2 kernel launches (~10us overhead each)
  • 6 global memory passes over the full [batch_size, vocab_size] tensor (4 reads + 2 writes)

For large-vocab models (Llama 3 128K, Qwen 152K), this becomes a meaningful bottleneck in the decode latency budget.

Modifications

New file: python/sglang/srt/layers/fused_sampling.py

Triton kernel that fuses div_(temperature) + softmax into a single kernel with two variants:

  • Single-pass kernel (vocab ≤ 32768): Loads entire vocab into registers once. Computes max, exp, sum, normalize all in-register. Only 1 read + 1 write = 2 memory passes.
  • Multi-pass kernel (vocab > 32768): 2-pass online softmax with @triton.autotune over (BLOCK_SIZE, num_warps, num_stages). 2 reads + 1 write = 3 memory passes.

Both variants apply logits / temperature during the load, avoiding a separate division pass.

Modified: python/sglang/srt/layers/sampler.py (3 lines changed)

  • Import fused_temperature_softmax_inplace
  • Replace the div_ + softmax two-line sequence with the fused kernel in the standard sampling path
  • Original implementation kept as fallback for non-CUDA platforms

New file: test/registered/sampling/test_fused_temperature_softmax.py

14 correctness tests comparing fused output against the reference (div_ + softmax) across:

  • Batch sizes: 1, 2, 16, 64, 128, 512
  • Vocab sizes: 1024, 32000, 128256
  • Temperatures: 0.01 to 100.0
  • Dtypes: bf16, fp16, fp32
  • Edge cases: empty batch, temperature=1.0

New file: benchmark/kernels/bench_fused_temperature_softmax.py

Microbenchmark using torch.cuda.Event timing (200 iters, 50 warmup).

Accuracy Tests

All 14 tests pass — fused output matches reference within tolerance (atol=1e-4, rtol=1e-3 for bf16):

test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_basic PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_large_vocab PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_batch_sizes PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_temperature_one PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_very_low_temperature PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_very_high_temperature PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_fp16_input PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_fp32_input PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_mixed_temperatures PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_empty_batch PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_inplace_basic PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_inplace_bf16 PASSED
test_fused_temperature_softmax.py::TestFusedTemperatureSoftmax::test_inplace_large_vocab PASSED

13 passed

Benchmarking and Profiling

Measured on NVIDIA A40 (SM 8.6), bf16, torch.cuda.Event timing, 200 iters after 50 warmup:

bs vocab original (us) fused inplace (us) speedup
1 32000 25.5 23.4 1.09x
1 128256 43.0 33.8 1.27x
32 32000 28.5 21.3 1.34x
32 128256 150.5 46.7 3.22x
128 32000 140.9 34.7 4.06x
128 128256 612.3 177.1 3.46x
512 32000 541.7 121.4 4.46x
512 128256 2416.8 694.8 3.48x

Speedup across all configurations: 1.09x–4.46x. Improvement comes from:

  • Eliminating 1 kernel launch (~10us)
  • Reducing memory passes from 6 to 2 (single-pass) or 3 (multi-pass)
  • Autotune selecting optimal (BLOCK_SIZE, num_warps, num_stages) per vocab size

Checklist

Godmook added 3 commits March 12, 2026 21:41
… for sampling

- Add fused_temperature_softmax Triton kernel (2-pass online softmax)
- Replace div_ + softmax with fused kernel in sampler standard path
- Add correctness tests and benchmark script
- Reduces kernel launches from 2 to 1 and memory passes from 6 to 3
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Copy link
Copy Markdown
Collaborator

@DarkSharpness DarkSharpness left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I just wonder how long it would take to compile & autotune this triton kernel? Also, the autotune must be done in warm-up stage and should not introduce extra runtime overhead.

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 30, 2026

LGTM. I just wonder how long it would take to compile & autotune this triton kernel? Also, the autotune must be done in warm-up stage and should not introduce extra runtime overhead.

Measured on A100 (SM 8.0).

1. Compile + Autotune overhead (first call per vocab_size)

vocab_size first call
32K 0.003s
128K 1.447s

Autotuning is cached by key=["vocab_size"], so changing batch_size within the same vocab_size does not trigger re-autotuning.


2. Runtime latency after warmup (200 iters, 50 warmup, bf16)

bs vocab cached (ms)
1 32K 0.051
32 32K 0.050
128 32K 0.050
512 32K 0.112
1 128K 0.063
32 128K 0.064
128 128K 0.118
512 128K 0.462

Once autotuned, runtime overhead is negligible (0.05–0.46ms).


3. Autotune overhead per vocab_size x dtype

vocab dtype autotune (s)
32K bf16 0.003
32K fp16 0.122
32K fp32 0.120
128K bf16 1.448
128K fp16 3.686
128K fp32 3.667

Each dtype triggers a separate autotune. fp16/fp32 takes ~30x longer than bf16.


4. Disk cache hit (first call after runtime restart)

vocab 1st call (s)
32K 0.003
128K 1.455

Even with Triton's ~/.triton/cache present, the same autotune cost is incurred after every process restart. The disk cache reuses compiled PTX/cubin but re-runs the autotune config search on every new process.

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 30, 2026

I added autotune.

Changes

1. python/sglang/srt/layers/fused_sampling.py

Added warmup_fused_temperature_softmax() at the end of the file.
Runs both fused_temperature_softmax and fused_temperature_softmax_inplace with a dummy input to trigger JIT compile and autotune before the first real request.

2. python/sglang/srt/model_executor/model_runner.py

Added _warmup_fused_sampling() call inside kernel_warmup(), after _flashinfer_autotune().


Boot flow

Server start
  → ModelRunner.__init__()
    → kernel_warmup()
      → _flashinfer_autotune()        # existing
      → _warmup_fused_sampling()      # new
        → warmup_fused_temperature_softmax(vocab_size)
          - vocab ≤ 32K : single-pass JIT only  (~3ms)
          - vocab > 32K : multi-pass autotune   (~1–4s)
    → init_device_graphs()
  → Ready to serve (no latency spike on first request)

Autotune is cached by key=["vocab_size"]. Warming up once with the model's vocab_size covers all batch sizes at runtime — no re-autotuning occurs during serving. Runtime overhead after warmup is 0.

Also, I reinforced test codes and refactor some minor changes. @DarkSharpness Could you check it? Thanks!

@Godmook Godmook requested a review from hnyls2002 as a code owner March 30, 2026 11:59
@DarkSharpness
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 30, 2026

/rerun-failed-ci

1 similar comment
@DarkSharpness
Copy link
Copy Markdown
Collaborator

/rerun-failed-ci

@DarkSharpness
Copy link
Copy Markdown
Collaborator

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Mar 31, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 1, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 1, 2026

/rerun-failed-ci

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 1, 2026

Sampling Kernel: Correctness Fix + Hybrid Dispatch

Problem

The original fused Triton kernel used a numerically different computation order
from PyTorch, causing structured output (grammar-constrained decoding) to fail.

Three sources of divergence:

Original kernel PyTorch path
Temperature scaling x * (1/temp) logits.div_(temp)
Softmax accumulation Online softmax (with correction term) 3-pass: max → sum → normalize
Normalization exp(x - max - log(sum)) exp(x - max) / sum

Fix

Rewrote the kernel to follow the same mathematical operation order as PyTorch
(x / temp → exp(x - max) / sum). Also removed the grammar fallback branch in
sampler.py — the fused kernel now always runs.

The only remaining difference is sub-ULP reduction ordering inside tl.max /
tl.sum vs PyTorch's CUDA kernels, which is extremely unlikely to affect
sampling results in practice.


Performance Impact

Aligning with PyTorch's algorithm required moving from 2-pass to 3-pass
reduction (one extra read pass), which regressed small-batch latency slightly.

Before (original kernel, numerically divergent from PyTorch):

bs vocab triton_ip (µs) flashinfer (µs) fi/tri
1 32K 63.0 61.5 0.98x
1 128K 77.2 62.0 0.80x
32 32K 61.3 64.2 1.05x
32 128K 77.1 86.1 1.12x
128 32K 62.5 74.1 1.19x
128 128K 118.8 308.2 2.59x
512 32K 110.8 259.5 2.34x
512 128K 463.6 1183.2 2.55x

After fix (numerically aligned with PyTorch, grammar decoding works):

bs vocab triton_ip (µs) flashinfer (µs) fi/tri
1 32K 50.6 43.5 0.70x
1 128K 70.3 44.2 0.56x
32 32K 50.2 44.0 0.72x
32 128K 76.7 81.0 0.95x
128 32K 50.4 71.0 1.15x
128 128K 176.6 308.7 1.74x
512 32K 113.5 259.7 1.49x
512 128K 609.1 1184.2 1.81x

Small-batch latency improved (~63µs → ~50µs at bs=1/32K) but triton_ip now
falls behind flashinfer at small batches. Large-batch advantage also narrowed
(bs=512/128K: 2.55x → 1.81x over flashinfer).


Hybrid Dispatch

To recover small-batch performance without sacrificing correctness, added a
dispatch layer:

  • bs < 128torch.div_ + torch.softmax (faster for small batches,
    mathematically identical)
  • bs >= 128 → fused Triton kernel (2–3x faster than flashinfer at large
    batches)

Both paths perform the same operations (x / temp → softmax) so grammar-
constrained decoding works correctly on either path.

Memory access pattern comparison:

Path Reads Writes
PyTorch (div_ + softmax) 4 2
Fused Triton (3-pass) 3 1

@DarkSharpness Do you think it is good approach or not?

@DarkSharpness
Copy link
Copy Markdown
Collaborator

/ temp → softmax) so grammar- constrained decoding works correctly on either path.

Memory access pattern comparison:

For now, the baseline is naive torch implementation. As long as your implementation is faster than torch native (div + softmax), it will be fine. We can try flashinfer solution in later PRs as follow-up.

@Godmook
Copy link
Copy Markdown
Contributor Author

Godmook commented Apr 2, 2026

@DarkSharpness I think XPU test and Xeon test has some problems now. I try to fix XPU CI tests in another PR now. So I'll try to do rerun and if there is no problem for other CIs, I think it is good now. I'll do rerun ci after 2 hours because of time limit...

top_p_renorm_prob,
)

from sglang.srt.layers.fused_sampling import fused_temperature_softmax_inplace
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this import to the top of the file?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fused_sampling pulls in Triton at import time. we keep this import inside if is_cuda(): with the other CUDA-only deps (flashinfer, sgl_kernel) so sampler stays importable on non-CUDA builds.

Thanks for help me :)

@BBuf BBuf merged commit 7a59e05 into sgl-project:main Apr 2, 2026
131 of 166 checks passed
@Godmook Godmook deleted the fused_sampling branch April 2, 2026 04:57
realray808 pushed a commit to Ascend/sglang that referenced this pull request Apr 3, 2026
* [AMD] Fix AMD CI monitor GitHub API rate limit exhaustion (sgl-project#21527)

* [CI] Register missing jit_kernel test files (sgl-project#21547)

* [diffusion] fix: return None instead of raising RuntimeError when no model info found (sgl-project#21319)

Co-authored-by: Mick <mickjagger19@icloud.com>

* [rl][sgl] fix tensor mismatch after pause (sgl-project#21514)

* [Hicache & JIT_kernel] Support page first layout  & mla jit kernel (sgl-project#18311)

* test: point DSV3 int8 MLA CI models to lmsys Hugging Face org (sgl-project#21561)

* [CI] Relax several thresholds in flaky CIs (sgl-project#21562)

* feat: add gc_threshold arg (sgl-project#21481)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Fix flaky test_pp_single_node (sgl-project#21564)

* Split workflow for releasing runtime docker (sgl-project#21563)

* fix tp capture in vit cuda graph (sgl-project#17255)

* [1/n] lora support - Auto detect lora target modules (sgl-project#21439)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* [fix] qwen3.5 fuse_moe_triton_tune bug (sgl-project#20232)

* Remove sync when enabling return_logprob (sgl-project#20972)

* Scope streaming backlog coalescing to incremental_streaming_output mode (sgl-project#21037)

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* docs: flesh out MAINTAINER.md oncall lists and link GitHub profiles (sgl-project#21575)

* [NVIDIA] Enable automatic NUMA configuration (sgl-project#19452)

* [diffusion] UX: aggregate expected dtype-cast logs during weight loading (sgl-project#21552)

* [diffusion] refactor: Unify `TeaCacheParams` and `WanTeaCacheParams` (sgl-project#20706)

Co-authored-by: Mick <mickjagger19@icloud.com>

* [diffusion] chore: remove redundant identity preprocess_text functions(sgl-project#20633)

Co-authored-by: Fengyuan Yu <15fengyuan@gmail.com>

* Update CODEOWNERS for transformers.py and docs (sgl-project#21555)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* reduce CPU peak memory in multimodal tensor hashing (sgl-project#21123)

* Fix HFRunner hang when subprocess dies during init (sgl-project#21582)

* Fix Piecewise CUDA Graph crash with `-enable-mixed-chunk` (sgl-project#20441)

Co-authored-by: jianyingzhu <joeyzhu@nvidia.com>

* [CI] Replace upload/download-artifact with job outputs in release-docker workflow (sgl-project#21579)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Patch transformers is_base_mistral in CI to avoid HF 429 rate limiting (sgl-project#21586)

* [CI] Move v32 cp test to deepep running suite (sgl-project#21585)

* [AMD] Add GLM-4.7-FP8 accuracy CI test for MI35x (sgl-project#21534)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* [Clean] Remove deprecated environs (sgl-project#21536)

* [diffusion] fix: fix Flux2-Klein prompt tokenization length to 512 and add regression coverage (sgl-project#21407)

* [CI] hot-fix ci lint (sgl-project#21608)

* [diffusion] feat: support overlay model materialization (sgl-project#21600)

* [VLM] Optimize ShmPointerMMData for multi-pickle safety and deferred unwrap (sgl-project#21465)

* feat: enable CUDA graph and timestamp for the whisper model(sgl-project#21190)

* [NPU] Update quantization&CI documentation (sgl-project#21100)

Co-authored-by: Tamir Baydasov <41994229+TamirBaydasov@users.noreply.github.com>

* Skip ci for .md files (sgl-project#21482)

* Support skip-softmax attention (sgl-project#19089)

* fix: piecewise_cuda_graph get correct qo_indptr (sgl-project#21452)

Co-authored-by: Avery Huang <averyh@nvidia.com>

* fix bench_serving sglang backend to support image dataset  (sgl-project#21294)

* [AMD] Add peft>=0.18.0 to diffusion_hip deps for transformers 5.x compat for AMD diffusion model (sgl-project#21442)

Co-authored-by: HaiShaw <hixiao@gmail.com>

* [GDN] Fuse GDN kkt + solve_tril into one kernel (sgl-project#21411)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* [Diffusion] Align diffusion benchmark skill presets with nightly comparison cases (sgl-project#21616)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Clean up detokenizer and remove dead multimodal_gen code (sgl-project#21588)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Skip flaky elastic EP test (sgl-project#21619)

* feat(ci): add GB300 nightly benchmark test suites (sgl-project#21487)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Lossen test_return_routed_experts threshold (sgl-project#21270)

* Add subprocess liveness monitor to detect scheduler crashes (sgl-project#18582)

Co-authored-by: 继优 <jiyou.ljy@alibaba-inc.com>
Co-authored-by: shuwenn <47200617+alphabetc1@users.noreply.github.com>

* fix: scheduler launch hang when non-current rank dies (sgl-project#20287)

* Wrap IPv6 addresses in gRPC, bench_serving, and log messages (sgl-project#21236)

Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>

* [HiCache] fix: graceful shutdown of pending async tasks in bench_mix.py (sgl-project#20276)

* Clean up _wait_for_scheduler_ready implementation (sgl-project#21626)

* fix cuda graph capturing error in sm120 mxfp8 triton path (sgl-project#19835)

* [sgl] disable piecewise cuda graph when a model doesn't have layers (sgl-project#21565)

* [Feature] Optimizations for JPEG input on NVIDIA GPU (sgl-project#19749)

* [VLM] perf: optimize CUDA IPC for multimodal transfer by caching IPC pool handles (sgl-project#21418)

* [Fix] SGLANG_USE_CUDA_IPC_TRANSPORT=1 and SGLANG_ENABLE_MM_SPLITTING=1 do not work at the same time. (sgl-project#19915)

* [Fix] Remove redundant allreduce fusion block and skip TP=1 (sgl-project#20621)

* Simplify routed experts test and move base64 encoding to tokenizer manager (sgl-project#21634)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Cleanup] Remove unused BatchMultimodalOutput and BatchMultimodalDecodeReq (sgl-project#21640)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Clean up TokenizerManager: remove dead code and improve rid validation (sgl-project#21639)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* README: coding agent sponsorship for long-term contributors (sgl-project#21642)

* Fix circular reference in CustomTestCase.__init_subclass__ (sgl-project#21650)

Co-authored-by: wan4ch <wan4ch@gmail.com>

* [Fix] Fix Qwen3.5 MoE model loading and Mamba cache sharding in PP mode (sgl-project#21448)

Co-authored-by: zhangxiaolei123456 <zhangxiaolei.666@bytedance.com>

* [diffusion] CI: fix dashboard chart (nightly) display issues (sgl-project#21653)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update sponsorship details in README.md (sgl-project#21658)

* [Fix] Handle pre-release tags in nightly wheel version parsing (sgl-project#21656)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Intel GPU] Enable DeepSeek R1 inference on XPU (sgl-project#18461)

Signed-off-by: P V R K Jyothendra Varma <polisetty.v.r.k.jyothendra.varma@intel.com>

* [Doc] Update tips for developer new-comers (sgl-project#21659)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [CI] [FlashInfer v0.6.7] Use offline quantized checkpoint for MXFP8 Gemm tests (sgl-project#21625)

* MFU metrics in Prometheus  (sgl-project#19395)

* fix topk softmax performance issue (sgl-project#14702)

* [CPU] add kernel apply_rotary_pos_emb_cpu for Qwen3-VL and Qwen3-Omni (sgl-project#13121)

Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>

* [CPU] Implement MXFP4 Gemm kernels for intel AMX to support GPT OSS series. (sgl-project#14385)

* [AMD] Fused rope kv store (sgl-project#21315)

Co-authored-by: wunhuang <wunhuang@amd.com>

* [NPU] Update DeepSeek-V3.2 model deployment instructions in documentation (sgl-project#21468)

Co-authored-by: wuxue (C) <w00964934@china.huawei.com>

* [AMD] Support AMD MXFP4 Qwen3.5-397B-A17B model (sgl-project#21234)

* [Fix] Fix weight_loader property assignment for qwen3-next FP8 models (sgl-project#21662)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix mamba cache leak when adder fails to add a matched req. (sgl-project#21404)

* fix: Mistral Small 4 fails to start due to config/weight format mismatch (sgl-project#21620)

Co-authored-by: mengxiancheng03 <mengxiancheng03@kuaishou.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [diffusion] feat: enhance overlay mechanism (sgl-project#21648)

* [diffusion] CI: relax pr-test threshold (sgl-project#21682)

* [NPU][Diffusion] fix sp modulate for qwen-image-edit (sgl-project#20974)

Co-authored-by: 高鑫 <gaoxin@gaoxindeMacBook-Pro.local>

* [NPU] fix eagle3 accept rate (sgl-project#21255)

* DeepSeek-R1-0528-w4a8: DeepEP Low Latency Dispatch Adopts FP8 Communication (sgl-project#14162)

Co-authored-by: undefined <zhouchen.arrebol@jd.com>

* [NPU] GLM-5 optimize with fused kernels (sgl-project#18617)

* [NPU][diffusion]: support parallel decoding of qwen-image (sgl-project#20757)

Co-authored-by: 高鑫 <gaoxin@gaoxindeMacBook-Pro.local>

* [diffusion] [NPU] support ring attention on NPU with FA (sgl-project#21383)

* [diffusion][doc]: add ring sp performance benchmark page (sgl-project#20998)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [GLM-V and GLM-4.7] Cast to FP32 before gate projection for GLM model. (sgl-project#21660)

* fix nemotron capture for non attention layers (sgl-project#21436)

* [Bugfix][NPU] Skip FRACTAL_NZ format for MoE weights with unaligned dimensions (sgl-project#21209)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>

* [AMD] Add SGLANG_DISAGGREGATION_NUM_PRE_ALLOCATE_REQS env var for configurable KV transfer overlap (sgl-project#20410)

Co-authored-by: HaiShaw <hixiao@gmail.com>

* [AMD][MoRI] bump MoRI to v0.1.0 (sgl-project#21673)

* [AMD] fix performance regression issue when run gpt-oss with "--context-length 13824" (sgl-project#21691)

* Remove flashinfer wheel cache cleanup that deletes other versions (sgl-project#21711)

Co-authored-by: Alison Shao <alison.shao@MacBook-Pro-D2W773R9CD.local>

* [misc] multiprocess compilation to speed up test (sgl-project#21483)

* Fix human-eval CI install on 5090 runners (sgl-project#21714)

Co-authored-by: Alison Shao <alison.shao@Mac.attlocal.net>

* Revert "DeepSeek-R1-0528-w4a8: DeepEP Low Latency Dispatch Adopts FP8 Communication" (sgl-project#21719)

* [Fix] Update supported custom_mem_pool types for mooncake (sgl-project#21728)

Co-authored-by: 百麒 <yaozhong.lyz@alibaba-inc.com>

* [Perf]Remove H2D  for Qwen3.5 SpecV2 (sgl-project#20864)

* [AMD] Fix CI multimodal-gen-test-1-gpu-amd for gen model  (sgl-project#21621)

* [diffusion] fix: fix Flux.2 with tp(sgl-project#21664)

* Add explicit disable flag for FlashInfer allreduce fusion (sgl-project#21446)

* [NPU] fix conflict between empty_cache and use_mem_pool (sgl-project#21507)

* [AMD] Use tgemm.mm for MoEGate router gemm in deepseek_v2.py (sgl-project#21657)

* [CI]Remove msgm-en and mmlu tests which cause timeout (sgl-project#21733)

* Fix disaggregation hybrid attention ci (sgl-project#21745)

* Rename rerun-ut to rerun-test (sgl-project#21747)

* bugfix(model):fix deepstack index out of range error (sgl-project#21727)

Co-authored-by: xiaoqi.31 <xiaoqi.31@jd.com>

* [diffusion] fix: fix typo (sgl-project#21746)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* [CI] Fix rerun-test suite detection to skip commented registrations (sgl-project#21753)

* [PD] Refactor Disagg Conn and Fix Hang with total_request/total_tokens Balancing (sgl-project#21299)

Co-authored-by: Weiliangl User <weiliangl@login-node.hosted.internal>

* [CI] Fix ring test timeout (sgl-project#21751)

* Enable evict swa with piecewise cuda graph (sgl-project#21754)

* Fix kimi-linear launch server error (sgl-project#21752)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* [PD] Tiny cleanup after KVReceiver refactor (sgl-project#21760)

Signed-off-by: Shangming Cai <csmthu@gmail.com>

* Fix remote weight info nnode>1 and dp>1 (sgl-project#17389)

* [diffusion] UX: replace deprecated ORJSONResponse with orjson_response (sgl-project#21755)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* [diffusion] fix: fix Wan2.2-I2V-A14B video max size issue(sgl-project#21390)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Mick <mickjagger19@icloud.com>

* [HiMambaTree]: Optimize mamba host lock mechanism (sgl-project#21750)

* [AMD] Fix Handle missing rope_theta in get_rope_config for Grok-1 (sgl-project#21518)

* [bugfix] Fix rope theta config for MiniMax after transformers v5 update (sgl-project#21241)

* Fix ineffective is_base_mistral CI patch for HF API rate limiting (sgl-project#21729)

* [2/n] lora - Shared outer experts and support qwen3_30b_a3b_instruct (sgl-project#21466)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* Fix cuda graph max bs capture upper bound (sgl-project#21005)

* [Fix] Fall back to triton MOE for GPT-OSS on Blackwell with driver >= 595 (sgl-project#21780)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Cache nvidia wheels locally to skip repeated 830 MB downloads in CI (sgl-project#21778)

* Add Trivy vulnerability scanning to nightly dev Docker builds (sgl-project#21772)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Remove more redundant PCG tests (sgl-project#21554)

* [moe] add customized option to moe-a2a-backend (sgl-project#21786)

* Add CompletionSampler for non-chat eval in run_eval (sgl-project#21785)

* Remove redundant test_moe_eval_accuracy_large (sgl-project#21787)

* Increase hicache eval to 200 examples (sgl-project#21791)

* Switch MooncakeSpec to EAGLE3 + Llama-3.1 (sgl-project#21794)

* Reduce redundant speculative decoding CI tests (sgl-project#21779)

* Fix killall.py crash when sglang is not yet installed (sgl-project#21797)

* Remove obsolete sgl-kernel legacy paths (sgl-project#21528)

* [jit_kernel] Optimize fused_qknorm_rope: deduplicate sincosf for interleave RoPE  (sgl-project#21654)

* CUTLASS NVFP4 GEMM improvement of SM120 (sgl-project#21314)

* [gRPC] Preserve original ImportError in grpc_server.py (sgl-project#21801)

Signed-off-by: Chang Su <chang.s.su@oracle.com>

* [Misc] Tiny: Add test network timeouts and dynamic max-parallel for 5090/2-gpu runners (sgl-project#21800)

* Fix draft extend cuda graph when spec_step=1 (sgl-project#21709)

* [Diffusion] Add `--uvicorn-access-log-exclude-prefixes` to suppress noisy access logs (sgl-project#20379)

* Add latency and throughput metrics to run_eval (sgl-project#21793)

* [diffusion] CI: improve ci reliability (sgl-project#21763)

* [bugfix]GLM-4V model (sgl-project#17122)

* Fix CVEs in Docker image: pillow, linux-libc-dev, and broken sgl-model-gateway build (sgl-project#21789)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: only showing recent runners from ci failure analysis (sgl-project#21015)

* [MPS] Fix Triton stub sub-module imports on Python 3.12+ (sgl-project#21551)

Co-authored-by: karanb192 <karan@example.com>
Co-authored-by: R0CKSTAR <yeahdongcn@gmail.com>
Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>

* [KDA] Fuse scaled_dot_kkt + solve_tril + recompute_w_u for KDA (sgl-project#21604)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* chore: bump flashinfer version to 0.6.7 (sgl-project#21422)

Co-authored-by: sglang-bot <sglang-bot@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* [3/n] lora moe - Support Qwen3-VL-30B-A3B-Instruct  (sgl-project#21469)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* [Feature Restoration] repetition_penalty is essential for GLM-V models (sgl-project#21258)

Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>

* VLM: change default mm-attention backend from triton_attn to fa4 (on blackwell) (sgl-project#21595)

* Fix added tokens config with sensible filter (sgl-project#17905)

* [AMD] Optimize Qwen3-VL decode - fuse QK-norm + 3D mRoPE + KV cache write (sgl-project#21458)

Co-authored-by: Bingxu Chen <bingxche@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>

* [Bugfix] Fix PP tied embeddings weight loading for qwen3.5 4B dense model (sgl-project#21347)

* [CI] Fix lint that was not applied in sgl-project#21458 (sgl-project#21818)

* Bug fix for llama eagle3 (sgl-project#21397)

* glm_interleave for GLM-V (sgl-project#21671)

* style refinement for hisparse (sgl-project#21198)

* [Bug][VLM] Fix shared memory race condition in ShmPointerMMData broadcast for multi-GPU VLM serving (sgl-project#21655)

* [Bugfix] Fix effective_mamba_size over-allocation (sgl-project#20858)

Co-authored-by: Shangming Cai <csmthu@gmail.com>

* Fix in-place mode in pause generation (sgl-project#21705)

* [diffusion] fix: respect --prompt-path (sgl-project#21756)

* [NPU] update ascend docs (sgl-project#21807)

* [VLM] remove AsyncMMDataProcessor wrapper (sgl-project#21651)

* Use CustomTestCase for TestSessionControl to enable CI retry (sgl-project#21830)

* [NPU]Add a full test pipeline on NPU, resolve issues in the NPU test architecture (sgl-project#20751)

* [diffusion][CI]: Add individual component accuracy CI for diffusion models (sgl-project#18709)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>

* [Feature] JIT rmsnorm update (with claude) (sgl-project#21834)

* [Diffusion][NPU] add ring sp performance benchmark page in npu (sgl-project#21811)

* fix(MiMo-V2-Flash): add mimo reasoning parser (sgl-project#21414)

* [diffusion] hardware: support FA3 attention backend on MUSA (attn backend, 14/N) (sgl-project#18648)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Mick <mickjagger19@icloud.com>

* fix: pre-init tokenizer_manager to avoid AttributeError in shutdown (sgl-project#21824)

* [FlashInver v0.6.7] Integrate flashinfer_trtllm mxfp8 gemm (sgl-project#21576)

* [Misc] Add network timeout to eval dataset downloads (sgl-project#21873)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [refactor] Clean up duplicate flashinfer trtllm moe code (sgl-project#21233)

* [DSA] Support trtllm sparse mla kernel for prefill batches  (sgl-project#21783)

* [Disagg] GPU staging buffer with dynamic ring allocator for heterogeneous TP KV transfer (sgl-project#19890)

* Add merge prohibition policy during CI maintenance mode (sgl-project#21882)

* [Misc] Fix comparator e2e tests: add polars dep + fix dp-attention test (sgl-project#21804)

Co-authored-by: Alison Shao <alison.shao@mac.lan>

* revert: remove TTL-based hard pin from HiRadixCache (sgl-project#21884)

* Unify GSM8K eval path to Chat API for regression CI readiness (sgl-project#21667)

* [HiCache] fix: Clone host indices to avoid memory leak (sgl-project#21624)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

* [HiCache & PD]Fixed detailed cache hit breakdown in PD scenarios. (sgl-project#21764)

* [CI] Add Llama 3.1 8B Instruct FP4 CI test on SM120 (sgl-project#20648)

* [CI] Add Per-Tensor, Blockwise FP8 Tests on SM120 (sgl-project#20717)

Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* Allow /rerun-test to checkout fork PR branch for trusted users (sgl-project#21890)

* Direct model loading from object storage with Runai Model Streamer (sgl-project#17948)

Signed-off-by: Noa Neria <noa@run.ai>

* fix pcg torch dynamo recompile in mxfp8 Triton path (sgl-project#21888)

Co-authored-by: Hanlin Bi <hanlinbi@umich.edu>

* chore: bump mooncake version to 0.3.10.post1 (sgl-project#21844)

* [VLM] Add VLM TP=4 per-commit CI test and improve MMMU eval prompt/parser (sgl-project#21841)

* fix(ci): update est_time for 57 tests based on runtime analysis (sgl-project#21896)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [CI] Increase multimodal server test timeout from 60 to 90 minutes (sgl-project#21897)

* [CI] Remove crashing Kimi K2.5 EAGLE3/MTP variants, keep TP8 and TP8+DP8 (sgl-project#21898)

* [diffusion] CI: add initial nvfp4 ci test for b200 (sgl-project#21767)

Co-authored-by: Mick <mickjagger19@icloud.com>

* Migrate all callers from /get_server_info to /server_info (sgl-project#21463)

* Support PP key for file backend (sgl-project#21901)

* Enable multi-thread weight loading by default (sgl-project#20289)

* Skip Go stdlib and NVIDIA tool CVEs in Trivy scan (sgl-project#21905)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Kernel] Fuse temperature + softmax in sampling for decode speedup (sgl-project#20501)

* Multi tool streaming fix (sgl-project#20004)

* Return HTTP 400 for streaming validation errors (sgl-project#21900)

* [Spec][Ngram] 4/N: Remove `max_match_window_size` and `min_match_window_size`, matching all suffixes of the Trie (sgl-project#21225)

* Fix ngram doc for speculative_num_draft_tokens default (sgl-project#21910)

* [NVIDIA] Enable fp8 flashinfer_trtllm_routed MoE for MiniMax-M2.5 (sgl-project#20394)

* scheduler: add prefill-only update in merge batch (sgl-project#21840)

* [DSA] Set trtllm kernels as nsa default for Blackwell (sgl-project#21914)

* Revert "Rollback flashmla to older version [1/2]" (sgl-project#21922)

* test: add manual init test for mooncake transfer engine (sgl-project#21842)

Co-authored-by: yunzhi <ningyunxiao.nyx@antgroup.com>

* Fix spec v2 + logprob when max_num_token is set (sgl-project#20799)

* Migrate ngram corpus from torch cpp_extension to TVM FFI jit_kernel (sgl-project#21920)

Co-authored-by: DarkSharpness <2040703891@qq.com>

* [NPU] Support  GLM-4.7-Flash on NPU (sgl-project#21408)

* [CI] Fix gpu deps import in cpu test (sgl-project#21950)

* [Parallel State Refactor 1/n] Remove stream of PyNCCL (sgl-project#20866)

* [diffusion] chore: fix stage profiler for multi-stage denoising (sgl-project#21955)

* [CI] [Tracing] Add ci for tracing and fix bugs (sgl-project#21740)

* Remove logging for subprocess watchdog start (sgl-project#21968)

* [4/n] Support gpt oss 20b lora (sgl-project#21570)

* [MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine) (sgl-project#17985)

Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>

* [Feature] Stronger transformers modeling backend with TP, PP, MoE, VLMs, and torch compile (sgl-project#19163)

* [CI] Remove stale Ascend suite entries from test/srt/run_suite.py (sgl-project#21978)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Skip broken AutoModel mapping entries when resolving Llava submodules (sgl-project#21892)

* [CI] Add timeouts to Slack upload urlopen and WebClient (sgl-project#21903)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [Diffusion][NPU] Add support for MOVA (sgl-project#21633)

Co-authored-by: zhangshuai (S) <z00836796@china.huawei.com>

* Remove maxItems=1 restriction when tool_choice is specified (sgl-project#20208)

* [Feature] NVFP4 Marlin fallback for non-Blackwell GPUs (SM75+) (sgl-project#19652)

* [PP] qwen3 vl skip layer id for pp (sgl-project#19135)

* [VLM] Enable per-image MM splitting by default and remove MULTI_IMAGES modality (sgl-project#21899)

* [Bugfix] Fix incorrect dp-attention parallel info in bench_one_batch (sgl-project#21519)

* Revert "[MUSA][9/N] Add FA3 attention backend support through MATE (MUSA AI Tensor Engine)" (sgl-project#22002)

* [NPU] Optimized the wording in the npu docs (sgl-project#21998)

* [Parallel State Refactor 2/n] Unify code path of AMD deterministic all reduce (sgl-project#20871)

* [AMD] Resolve the performance degression when launch server with "--enable-aiter-allreduce-fusion" (sgl-project#21947)

Co-authored-by: wunhuang <wunhuang@amd.com>

* chore: bump sgl-kernel version to 0.4.1 (sgl-project#21447)

Co-authored-by: sglang-bot <sglang-bot@users.noreply.github.com>

* [Workflow] Avoid triggering nightly tests in kernel bump workflow (sgl-project#22010)

* [Workflow] Fix kernel release jobs skipped on push events (sgl-project#22011)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [PD]: Add support for HiSparse to directly transfer the cache from Prefill to Decode DRAM. (sgl-project#21591)

Co-authored-by: Tingwei Huang <huangtingwei9988@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

* [Misc] Update CI permission (sgl-project#22014)

* [ROCM][RL] Shuffle Weight In-Place to Preserve Parameter Attributes (sgl-project#21825)

* [CI] Fix duplicate job names that bypass branch protection (sgl-project#22001)

* fix: remove duplicate words in comments (sgl-project#22007)

* [PD] Tiny register info field cleanup for mooncake backend (sgl-project#22016)

* [NPU] optimize glm4.7 (sgl-project#19246)

* [AMD] Enable FP8 KV cache and FP8 attention kernel for NSA on MI300/MI355 with TileLang backend (sgl-project#21511)

* [AMD] Add MiniMax-M2.5 nightly perf benchmarks for MI30x and MI35x (sgl-project#21524)

---------

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Signed-off-by: P V R K Jyothendra Varma <polisetty.v.r.k.jyothendra.varma@intel.com>
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Signed-off-by: Chang Su <chang.s.su@oracle.com>
Signed-off-by: Noa Neria <noa@run.ai>
Co-authored-by: Bingxu Chen <bingxche@amd.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: yang1002378395-cmyk <yang1002378395@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Bi Xue <bi@thinkingmachines.ai>
Co-authored-by: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Muqi Li <muqi1029@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: narutolhy <582909902@qq.com>
Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>
Co-authored-by: zhangxiaolei <zhangxiaolei.666@bytedance.com>
Co-authored-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Co-authored-by: Trevor Morris <tmorris@nvidia.com>
Co-authored-by: Eitan Turok <150733043+eitanturok@users.noreply.github.com>
Co-authored-by: Fengyuan Yu <Yuandao151112@163.com>
Co-authored-by: Fengyuan Yu <15fengyuan@gmail.com>
Co-authored-by: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com>
Co-authored-by: Yuhao Yang <47235274+yhyang201@users.noreply.github.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Jianying <53503712+jianyingzhu@users.noreply.github.com>
Co-authored-by: jianyingzhu <joeyzhu@nvidia.com>
Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jacob0226 <jacchang@amd.com>
Co-authored-by: Aditya Sharma <89210949+adityavaid@users.noreply.github.com>
Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Артем Савкин <58187114+OrangeRedeng@users.noreply.github.com>
Co-authored-by: Tamir Baydasov <41994229+TamirBaydasov@users.noreply.github.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: eigen <52445717+yyihuang@users.noreply.github.com>
Co-authored-by: Avery Huang <averyh@nvidia.com>
Co-authored-by: jacky.cheng <yichiche@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Junrong Lin <33685709+ocss884@users.noreply.github.com>
Co-authored-by: Simon (Jiyou) Li <Simon-Li@users.noreply.github.com>
Co-authored-by: 继优 <jiyou.ljy@alibaba-inc.com>
Co-authored-by: shuwenn <47200617+alphabetc1@users.noreply.github.com>
Co-authored-by: psaab <ps@meta.com>
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
Co-authored-by: Hanlin Bi <52993433+wolfcomos@users.noreply.github.com>
Co-authored-by: wili <98001977+wili-65535@users.noreply.github.com>
Co-authored-by: saatwiknagpal <saatwiknagpal@gmail.com>
Co-authored-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
Co-authored-by: wan4ch <wan4ch@gmail.com>
Co-authored-by: Feng Su <sufeng@linux.alibaba.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Polisetty V R K Jyothendra Varma <polisetty.v.r.k.jyothendra.varma@intel.com>
Co-authored-by: Ziang Li <ziangli@umich.edu>
Co-authored-by: Aishwarya Ramasethu <56765596+aramasethu@users.noreply.github.com>
Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
Co-authored-by: blzheng <beilei.zheng@intel.com>
Co-authored-by: kk <43161300+kkHuang-amd@users.noreply.github.com>
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Michelle Wu <michellewu351@gmail.com>
Co-authored-by: wuxue (C) <w00964934@china.huawei.com>
Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
Co-authored-by: strgrb <zhangkaihong.zkh@antgroup.com>
Co-authored-by: LiYomi <106872109+LiYomi@users.noreply.github.com>
Co-authored-by: mengxiancheng03 <mengxiancheng03@kuaishou.com>
Co-authored-by: GXIN <37653830+gxxx-hum@users.noreply.github.com>
Co-authored-by: 高鑫 <gaoxin@gaoxindeMacBook-Pro.local>
Co-authored-by: heziiop <q_m_p@qq.com>
Co-authored-by: xieminghe1 <141820649+xieminghe1@users.noreply.github.com>
Co-authored-by: undefined <zhouchen.arrebol@jd.com>
Co-authored-by: Makcum888e <79456407+Makcum888e@users.noreply.github.com>
Co-authored-by: yuefeng Wu <33725817+ChefWu551@users.noreply.github.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: Vedant V Jhaveri <vedantjh2@gmail.com>
Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: Zhai Feiyue <80079571+ZhaiFeiyue@users.noreply.github.com>
Co-authored-by: jhchouuu <jiahzhou@amd.com>
Co-authored-by: Alison Shao <54658187+alisonshao@users.noreply.github.com>
Co-authored-by: Alison Shao <alison.shao@MacBook-Pro-D2W773R9CD.local>
Co-authored-by: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com>
Co-authored-by: Alison Shao <alison.shao@Mac.attlocal.net>
Co-authored-by: Lewis <63569348+TTThanos@users.noreply.github.com>
Co-authored-by: 百麒 <yaozhong.lyz@alibaba-inc.com>
Co-authored-by: Jincong Chen <jincong.cjc@ant-intl.com>
Co-authored-by: xiazhahe <86939755+xiazhahe@users.noreply.github.com>
Co-authored-by: Thomas Wang <thomawan@amd.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
Co-authored-by: xiaoqi <xq25478@qq.com>
Co-authored-by: xiaoqi.31 <xiaoqi.31@jd.com>
Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>
Co-authored-by: weireweire <weiliangl@nvidia.com>
Co-authored-by: Weiliangl User <weiliangl@login-node.hosted.internal>
Co-authored-by: JD <jaedon.guo@gmail.com>
Co-authored-by: Zhangheng <hzh0425@apache.org>
Co-authored-by: Michael <13900043+michaelzhang-ai@users.noreply.github.com>
Co-authored-by: Yilong Zhao <74357408+happierpig@users.noreply.github.com>
Co-authored-by: Johnsonms <lizhaofu@gmail.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: KnightLTC <56717110+KnightLTC@users.noreply.github.com>
Co-authored-by: Douglas Yang <dyang@college.harvard.edu>
Co-authored-by: Karan Bansal <karanb192@users.noreply.github.com>
Co-authored-by: karanb192 <karan@example.com>
Co-authored-by: R0CKSTAR <yeahdongcn@gmail.com>
Co-authored-by: sglang-bot <sglangbot@gmail.com>
Co-authored-by: sglang-bot <sglang-bot@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: sbeurnier <sbeurnier@together.ai>
Co-authored-by: YC Yen-Ching Tseng <yctseng@amd.com>
Co-authored-by: Wenyao Gao <105094497+edwingao28@users.noreply.github.com>
Co-authored-by: Alex Nails <alex.nails@radixark.ai>
Co-authored-by: khalilzhk <khalilzhk@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: yudian0504 <138860534+yudian0504@users.noreply.github.com>
Co-authored-by: yunkchen <chenyunkuo.cyk@alibaba-inc.com>
Co-authored-by: wduan-hai <wduan@humansand.ai>
Co-authored-by: amote-i <49533125+amote-i@users.noreply.github.com>
Co-authored-by: Cherry_ming <136634645@qq.com>
Co-authored-by: Ratish P <114130421+Ratish1@users.noreply.github.com>
Co-authored-by: YAMY <74099316+YAMY1234@users.noreply.github.com>
Co-authored-by: Alison Shao <alison.shao@mac.lan>
Co-authored-by: ishandhanani <82981111+ishandhanani@users.noreply.github.com>
Co-authored-by: Derek Yu <81697272+DerekY2@users.noreply.github.com>
Co-authored-by: Noa Neria <noa@run.ai>
Co-authored-by: Hanlin Bi <hanlinbi@umich.edu>
Co-authored-by: Prozac614 <dwt614707404@163.com>
Co-authored-by: David Cheung <d7cheung@gmail.com>
Co-authored-by: Mook <68294499+Godmook@users.noreply.github.com>
Co-authored-by: Khoa Pham <khoa.pham@radixark.ai>
Co-authored-by: foraxe <73625538+foraxe@users.noreply.github.com>
Co-authored-by: yunzhi <ningyunxiao.nyx@antgroup.com>
Co-authored-by: DarkSharpness <2040703891@qq.com>
Co-authored-by: Todobe <43903496+Todobe@users.noreply.github.com>
Co-authored-by: ori <39351881+froststeam@users.noreply.github.com>
Co-authored-by: Thomas <zs033@qq.com>
Co-authored-by: zhangshuai (S) <z00836796@china.huawei.com>
Co-authored-by: lviy <142899752+lviy@users.noreply.github.com>
Co-authored-by: Tingwei Huang <huangtingwei9988@gmail.com>
Co-authored-by: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com>
Co-authored-by: Ricardo-M-L <69202550+Ricardo-M-L@users.noreply.github.com>
Co-authored-by: Kelon <kelonlu@163.com>
Co-authored-by: cen121212 <luochen23@huawei.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants