Skip to content

Integrate fused flashinfer rope_quantize_fp8_append_paged_kv_cache kernel#19451

Open
leejnau wants to merge 16 commits intosgl-project:mainfrom
leejnau:integrate_mla_rope_quantize_fp8_append_paged_mla_kv_cache
Open

Integrate fused flashinfer rope_quantize_fp8_append_paged_kv_cache kernel#19451
leejnau wants to merge 16 commits intosgl-project:mainfrom
leejnau:integrate_mla_rope_quantize_fp8_append_paged_mla_kv_cache

Conversation

@leejnau
Copy link
Collaborator

@leejnau leejnau commented Feb 26, 2026

Motivation

Currently there are two kernels used to perform RoPE + Quantization (first kernel) + KV Cache writing (second kernel). Flashinfer now exposes a single fused kernel that performs RoPE + Quantization + KV Cache: flashinfer-ai/flashinfer#2037
We should use the single fused kernel for improved performance.

Modifications

Call the single fused flashinfer kernel instead of the two separate kernels where possible:

  • This only applies for the case where KV Cache writing is to occur.
  • This is only supported for KV Cache of type FP8.

The CUDA graph buffers must be pre-allocated for various metadata requirements.

Accuracy Tests

The container used for testing was lmsysorg/sglang:v0.5.8. The model tested was DeepSeek R1 FP8 (https://huggingface.co/deepseek-ai/DeepSeek-R1-0528).

GPQA Accuracy

server:

SGLANG_ENABLE_JIT_DEEPGEMM=false python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-0528 --tensor-parallel-size=8 --mem-fraction-static 0.85 --ep-size 8 --fp8-gemm-backend flashinfer_trtllm --kv-cache-dtype fp8_e4m3

client:

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 128000 --repeat 8 --thinking-mode deepseek-v3

Fused kernel GPQA: 0.779625

Benchmarking and Profiling

NSys Profile Analysis

The same container was mentioned in the Accuracy Tests section was used.

server:

SGLANG_ENABLE_JIT_DEEPGEMM=false python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-0528 --tensor-parallel-size=8 --mem-fraction-static 0.85 --ep-size 8 --fp8-gemm-backend flashinfer_trtllm --kv-cache-dtype fp8_e4m3

client:

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 128 --random-input 1024 --random-output 1024 --random-range-ratio 1 --max-concurrency <N>

A variety of concurrency values <N> were used: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048. These were tested across both the FP4 (https://huggingface.co/nvidia/DeepSeek-R1-0528-NVFP4) and FP8 (https://huggingface.co/deepseek-ai/DeepSeek-R1-0528) DeepSeek R1 models.

On average approximately 1 microsecond per layer was saved using the fused kernel, as observed from examining nsys profiles. As an example, for the FP4 model at concurrency 16, the fused kernel had a duration of 3.691 $\mu\text{s}$:

fused_conc16_fp4

Whereas the two kernels had a duration of 4.733 $\mu\text{s}$:

main_conc16_fp4

Note that these durations include any gaps before or after the kernel(s).

Stress Tests Isolating Decode

A "low noise decode focused" benchmark was run with concurrencies 1 and 4. This should serve to better isolate the performance improvement in the decode step.

server:

fp4:

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-V2 --tensor-parallel-size=8 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 8 --enable-symm-mem --moe-runner-backend flashinfer_cutlass --quantization modelopt_fp4

fp8:

SGLANG_ENABLE_JIT_DEEPGEMM=false python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-0528 --tensor-parallel-size=8 --mem-fraction-static 0.85 --ep-size 8 --fp8-gemm-backend flashinfer_trtllm --kv-cache-dtype fp8_e4m3

client:

# Profile: low_noise_decode_focus, concurrency=1
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 256 --random-input 16 --random-output 2048 --random-range-ratio 0.0 --max-concurrency 1

# Profile: low_noise_decode_focus, concurrency=4
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 256 --random-input 16 --random-output 2048 --random-range-ratio 0.0 --max-concurrency 4

Key Results (Baseline vs Fused)

Model Profile Conc Base Throughput Fused Throughput Throughput Δ Base TPOT Fused TPOT TPOT Δ
fp4 low_noise_decode_focus 1 111.45 112.11 +0.59% 8.861 8.803 +0.66%
fp4 low_noise_decode_focus 4 331.27 332.80 +0.46% 11.898 11.840 +0.49%
fp8 low_noise_decode_focus 1 124.05 124.72 +0.54% 7.985 7.944 +0.52%
fp8 low_noise_decode_focus 4 379.29 381.28 +0.52% 10.404 10.350 +0.52%

Takeaways

  • Low-noise cases (conc=1/4) show a consistent ~0.5-0.7% fused win in both throughput and TPOT for fp4 and fp8 models.
  • This matches the expected impact from the microsecond-level per-layer kernel savings seen in traces.

Conclusion

The fused kernel provides a real, consistent micro-optimization (~+0.5% class) in low-noise serving benchmarks, with no clear regression signal in those checks.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@github-actions github-actions bot added quant LLM Quantization blackwell SM100/SM120 npu labels Feb 26, 2026
@leejnau leejnau force-pushed the integrate_mla_rope_quantize_fp8_append_paged_mla_kv_cache branch 2 times, most recently from 09b7b96 to c6387ca Compare February 27, 2026 22:13
@leejnau leejnau force-pushed the integrate_mla_rope_quantize_fp8_append_paged_mla_kv_cache branch from c6387ca to f17fd50 Compare February 27, 2026 23:08
@leejnau leejnau marked this pull request as ready for review March 2, 2026 20:43
@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Copy link
Collaborator

@nvpohanh nvpohanh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! thanks

@nvpohanh
Copy link
Collaborator

This can be reviewed together with #15729 . They are very similar, except that one is for trtllm_mha and one is for trtllm_mla

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

blackwell SM100/SM120 npu quant LLM Quantization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants