
Conversation

@YAMY1234 commented Nov 2, 2025

Motivation

Fuse MLA RoPE + FP8 quantization + KV-cache write into a single CUDA kernel to reduce kernel launches and global-memory traffic during decode. This primarily targets MLA’s FP8 path where we previously (1) applied RoPE, (2) quantized Q/K to FP8, and (3) wrote K into KV cache in separate steps.

Modifications

  • Add a fused CUDA kernel, mla_rope_fp8_kv_fused.cu, that:
    • applies RoPE to the rope parts,
    • quantizes {Q_nope, Q_rope, K_nope, K_rope} to FP8,
    • emits Q (packed) and writes K directly into the KV cache when cache locations are provided.
      (A simplified sketch of the fused per-token work appears after this list.)
  • Wire a fused path in python/sglang/srt/layers/attention/trtllm_mla_backend.py that calls the fused op when available.
  • Add tests & microbench:
    • test/srt/test_mla_fp8_fused.py for correctness,
    • benchmark/kernels/bench_flashmla_fused_kv.py for microbenchmarking.
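
To make the data flow concrete, below is a heavily simplified sketch of the per-token work the fused kernel performs. The tensor layouts, parameter names, interleaved cos/sin layout, and per-tensor scale handling are assumptions for illustration only; the actual kernel in sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu also covers the nope parts, vectorizes loads/stores, and resolves paged KV-cache locations.

#include <cuda_fp8.h>
#include <cstdint>

// Hypothetical sketch of the fused op: RoPE + FP8 quantization + KV-cache
// write in one pass (not the actual kernel from this PR).
__global__ void mla_rope_fp8_kv_fused_sketch(
    const float* __restrict__ q_rope,     // [nnz, Dr] rope part of Q
    const float* __restrict__ k_rope,     // [nnz, Dr] rope part of K
    const float* __restrict__ cos_sin,    // [nnz, Dr] interleaved cos/sin (assumed layout)
    const int64_t* __restrict__ kv_loc,   // [nnz] destination slot in the KV cache
    __nv_fp8_e4m3* __restrict__ q_out,    // [nnz, Dr] packed FP8 Q output
    __nv_fp8_e4m3* __restrict__ kv_cache, // FP8 KV cache (rope part of each slot)
    float q_scale, float k_scale,         // assumed per-tensor FP8 scales
    int Dr) {
  const int token = blockIdx.x;           // one block per token
  const int d = threadIdx.x;              // one rotation pair per thread
  if (2 * d + 1 >= Dr) return;

  const float c = cos_sin[token * Dr + 2 * d];
  const float s = cos_sin[token * Dr + 2 * d + 1];

  // (1) RoPE on the rope part of Q and (2) FP8 quantization, in one pass.
  float xr = q_rope[token * Dr + 2 * d];
  float xi = q_rope[token * Dr + 2 * d + 1];
  q_out[token * Dr + 2 * d]     = __nv_fp8_e4m3((xr * c - xi * s) / q_scale);
  q_out[token * Dr + 2 * d + 1] = __nv_fp8_e4m3((xr * s + xi * c) / q_scale);

  // Same for K, but (3) the quantized result goes straight into the KV-cache
  // slot for this token, skipping the separate cache-write kernel.
  xr = k_rope[token * Dr + 2 * d];
  xi = k_rope[token * Dr + 2 * d + 1];
  const int64_t slot = kv_loc[token];
  kv_cache[slot * Dr + 2 * d]     = __nv_fp8_e4m3((xr * c - xi * s) / k_scale);
  kv_cache[slot * Dr + 2 * d + 1] = __nv_fp8_e4m3((xr * s + xi * c) / k_scale);
}
// Launch sketch: mla_rope_fp8_kv_fused_sketch<<<nnz, Dr / 2>>>(...);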
MLA RoPE + FP8 Quantization + KV Cache Write Fusion Benchmark
======================================================================
Config: Dn=512, Dr=64, iters=1000, warmup=100
======================================================================
nnz= 1024 | baseline=  0.020 ms | fused=  0.004 ms | speedup x4.76 (+375.8%)
nnz= 4096 | baseline=  0.029 ms | fused=  0.006 ms | speedup x4.66 (+366.0%)
nnz= 8192 | baseline=  0.037 ms | fused=  0.008 ms | speedup x4.50 (+349.8%)
nnz=16384 | baseline=  0.059 ms | fused=  0.012 ms | speedup x4.83 (+383.4%)
nnz=32768 | baseline=  0.119 ms | fused=  0.029 ms | speedup x4.14 (+313.6%)

Accuracy Tests

SGLANG_ENABLE_SPEC_V2=1 python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-FP4-v2 \
  --trust-remote-code \
  --quantization modelopt_fp4 \
  --tp 8 \
  --port 40020 \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --kv-cache-dtype fp8_e4m3 \
  --page-size 64

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port 40020
100%|██████████████████████████████████████████████████████████████████| 1319/1319 [02:00<00:00, 10.96it/s]
Accuracy: 0.955
Invalid: 0.000
Latency: 120.511 s

Benchmarking and Profiling

SGLANG_ENABLE_SPEC_V2=1 python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-0528-FP4-v2 \
  --trust-remote-code \
  --quantization modelopt_fp4 \
  --tp 8 \
  --port 40020 \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --kv-cache-dtype fp8_e4m3 \
  --page-size 64

fused (~3 µs): [profiler trace screenshot]

main (4+ µs plus intermediate ops): [profiler trace screenshots]

send_one comparison:

main:

python3 -m sglang.test.send_one --port 40020
+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    1.499    |  512   |   3.048    |     341.46      |
+-------------+--------+------------+-----------------+

This PR:

python3 -m sglang.test.send_one --port 40020
+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    1.477    |  512   |   3.048    |     346.73      |
+-------------+--------+------------+-----------------+

@gemini-code-assist

Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement by fusing three critical operations—MLA RoPE application, FP8 quantization, and KV-cache writing—into a single CUDA kernel. This fusion aims to minimize kernel launch overheads and global memory access, which are bottlenecks in the FP8 decode path. The changes include the implementation of the new fused kernel, its integration into the existing attention backend with conditional activation for modern GPU architectures, and comprehensive benchmarking and testing to validate both performance gains and accuracy preservation.

Highlights

  • Performance Optimization: Fused MLA RoPE application, FP8 quantization, and KV-cache write into a single CUDA kernel to significantly reduce kernel launch overheads and global-memory traffic during the decode phase, primarily for the MLA's FP8 path.
  • New CUDA Kernel: Introduced a new CUDA kernel, mla_rope_fp8_kv_fused.cu, which performs all three operations (RoPE, FP8 quantization, KV-cache write) in a single pass, emitting packed Q and directly writing K to the KV cache.
  • Conditional Activation: The fused kernel is conditionally activated in trtllm_mla_backend.py when running on NVIDIA GPUs with compute capability SM90+ (e.g., H100/B200), leveraging hardware support for FP8 (an equivalent capability check is sketched after this list).
  • Benchmarking: Microbenchmarks show a substantial speedup of 4.14x to 4.83x for the fused operation compared to the baseline, and overall system tests indicate a slight improvement in token generation speed (from 341.46 to 346.73 tokens/s).
  • Accuracy Preservation: Accuracy tests confirm that the fusion does not negatively impact model performance, maintaining an accuracy of 0.955.
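
For reference, the capability gate described above looks roughly like this when expressed with the CUDA runtime API. The PR itself performs the check in Python inside trtllm_mla_backend.py; this standalone helper (with a hypothetical name) is only illustrative.

#include <cuda_runtime.h>

// Illustrative only: equivalent of the Python-side SM90+ gate for the fused path.
static bool supports_fused_fp8_path(int device) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;  // be conservative: fall back to the unfused path
  }
  // SM90 (Hopper, e.g. H100) and newer (e.g. B200) expose native FP8 support.
  return prop.major >= 9;
}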


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused CUDA kernel for MLA RoPE, FP8 quantization, and KV cache writing, which is a significant performance optimization. The changes are well-structured, including a new standalone build for the kernel, benchmarks, and correctness tests. The Python-side logic correctly integrates the new kernel, with a fallback to the old path, ensuring safety and compatibility. The CUDA kernel itself is well-written with vectorization for performance and a scalar fallback. I have a couple of suggestions for improving code maintainability in the setup script and the CUDA kernel. Overall, this is a high-quality contribution.

Comment on lines +41 to +44
"../sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu",
],
include_dirs=[
"../sgl-kernel/include",

Severity: medium

The sources and include_dirs paths for the CUDAExtension use relative paths ("../sgl-kernel/..."). This makes the build script dependent on the current working directory from which setup.py is invoked. To make it more robust, you could construct absolute paths based on the location of the setup.py file itself.

Suggested change

-    "../sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu",
-],
-include_dirs=[
-    "../sgl-kernel/include",
+    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "sgl-kernel", "csrc", "elementwise", "mla_rope_fp8_kv_fused.cu"),
+],
+include_dirs=[
+    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "sgl-kernel", "include"),

return (uint32_t)a0 | ((uint32_t)a1 << 8) | ((uint32_t)a2 << 16) | ((uint32_t)a3 << 24);
}

__device__ inline void rope_rotate(float& xr, float& xi, float c, float s, bool) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Severity: medium

The bool parameter in the rope_rotate function signature is unnamed and unused within the function. It seems to be a remnant from a previous implementation. For better code clarity and maintainability, it's best to remove it from the function definition and all its call sites within this file.

__device__ inline void rope_rotate(float& xr, float& xi, float c, float s) {
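
For context, a minimal sketch of the helper with the flag removed, plus a matching call site. The rotation body shown here is the standard pairwise rotation and is an assumption; the real implementation lives in mla_rope_fp8_kv_fused.cu.

// Hypothetical body: standard RoPE pairwise rotation, unused bool removed.
__device__ inline void rope_rotate(float& xr, float& xi, float c, float s) {
  const float r = xr * c - xi * s;  // real part: x_r*cos - x_i*sin
  const float i = xr * s + xi * c;  // imag part: x_r*sin + x_i*cos
  xr = r;
  xi = i;
}
// Call sites then simply drop the trailing flag, e.g.
//   rope_rotate(q[2 * d], q[2 * d + 1], cos_v, sin_v);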

@Qiaolin-Yu Qiaolin-Yu self-assigned this Nov 2, 2025