
[Feature] Add FP4 KV cache support for SM120 GPUs#21601

Open
samuellees wants to merge 1 commit into sgl-project:main from samuellees:nvfp4-kvcache-sm120

Conversation

@samuellees
Contributor

@samuellees samuellees commented Mar 28, 2026

Rebased from #18314

  • Enable NVFP4 KV Cache for SM120
  • Extract KVCacheQuantMethod ABC with NoneMethod/NVFP4Method/MXFP4Method subclasses
  • Migrate quantize kernel to flashinfer fp4_quantize, keep custom CUDA dequant
  • Add MTP (speculative decoding) support for NVFP4/FP8 KV cache
  • XQA decode kernel integration with proper scale factor handling
  • Centralize FP4 buffer creation, quantize/dequant, and scale management

Summary

Add NVFP4 (FP4 E2M1) KV cache quantization support for Blackwell GPUs, reducing KV cache memory by roughly 2x compared to FP8, with no GSM8K accuracy loss in the non-MTP setting.
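As a rough sanity check on the ~2x figure, the per-element storage cost can be worked out directly. The 16-element micro-block below is the standard NVFP4 layout but is an assumption about this implementation, not taken from the PR:

```python
# Back-of-the-envelope KV-cache storage per element. The 16-element block size
# is the standard NVFP4 micro-block and is assumed here, not read from the PR.
BLOCK_SIZE = 16

fp8_bytes_per_elem = 1.0                     # FP8 E4M3: one byte per value
fp4_bytes_per_elem = 0.5 + 1.0 / BLOCK_SIZE  # 4-bit payload + amortized FP8 block scale

reduction = fp8_bytes_per_elem / fp4_bytes_per_elem
print(f"FP8 -> FP4 memory reduction: {reduction:.2f}x")  # about 1.78x, i.e. the "~2x" above
```

The per-tensor FP32 global scale is a single value and does not materially change this accounting.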

Key changes

  • Strategy pattern for KV cache quantization: Introduce KVCacheQuantMethod ABC with NoneMethod,
    NVFP4Method, MXFP4Method subclasses (kv_cache_quant_method.py). Adding a new FP4 scheme only requires
    implementing one subclass and registering it.
  • NVFP4 two-level scaling: Per-tensor FP32 global scale + per-block FP8 E4M3 scale factors, stored
    alongside FP4 KV data.
  • Kernel dispatch:
    • Prefill: FlashInfer dequantizes FP4→FP8, then runs standard FP8 prefill kernel
    • Decode: TRT-LLM XQA kernel reads FP4 natively with two-level scales
  • Quantize via flashinfer fp4_quantize: Replaces custom JIT CUDA kernel with flashinfer's optimized
    implementation.
  • MTP (Multi-Token Prediction) support: target_verify / draft_extend route through XQA decode kernel
    with causal masking (--speculative-attention-mode decode).
  • Hybrid model support: Mamba state update works correctly under speculative attention mode for hybrid
    models (e.g., Qwen3.5-35B-A3B).
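The strategy pattern in the first bullet can be sketched as follows. The class names (`KVCacheQuantMethod`, `NoneMethod`, `NVFP4Method`) come from the PR, but the method signatures and the registry are illustrative assumptions, not the actual `kv_cache_quant_method.py` API:

```python
from abc import ABC, abstractmethod

class KVCacheQuantMethod(ABC):
    """Base strategy for KV-cache quantization (signatures are hypothetical)."""

    @abstractmethod
    def quantize(self, kv):
        ...

    @abstractmethod
    def dequantize(self, kv_q):
        ...

class NoneMethod(KVCacheQuantMethod):
    # BF16/FP16 path: KV is stored as-is.
    def quantize(self, kv):
        return kv

    def dequantize(self, kv_q):
        return kv_q

class NVFP4Method(KVCacheQuantMethod):
    # FP4 E2M1 payload with two-level scales; the real work happens in CUDA kernels.
    def quantize(self, kv):
        return ("fp4_e2m1", kv)

    def dequantize(self, kv_q):
        tag, data = kv_q
        return data

# Hypothetical registry: adding a new FP4 scheme means one subclass + one entry.
QUANT_METHODS = {"auto": NoneMethod, "fp4_e2m1": NVFP4Method}

method = QUANT_METHODS["fp4_e2m1"]()
roundtrip = method.dequantize(method.quantize([1.0, 2.0]))
```

The point of the pattern is that callers (memory pool, attention backends) only see the base interface, so a new scheme such as MXFP4 plugs in without touching the dispatch sites.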

Usage

python3 -m sglang.launch_server \
    --model-path <model_path> \
    --kv-cache-dtype fp4_e2m1 \
    --prefill-attention-backend flashinfer \
    --decode-attention-backend trtllm_mha \
    --disable-radix-cache

Benchmark (Qwen3.5-35B-A3B, GSM8K 100q, Blackwell)

┌────────────────┬─────────────┬──────────┬────────────────────┐
│    KV Cache    │     MTP     │ Accuracy │     Throughput     │
├────────────────┼─────────────┼──────────┼────────────────────┤
│ FP8 (fp8_e4m3) │ No          │ 91.0%    │ 350.6 tok/s        │
├────────────────┼─────────────┼──────────┼────────────────────┤
│ FP4 (fp4_e2m1) │ No          │ 91.0%    │ 452.1 tok/s (+29%) │
├────────────────┼─────────────┼──────────┼────────────────────┤
│ FP8 (fp8_e4m3) │ draft_len=3 │ 96.0%    │ 652.2 tok/s        │
├────────────────┼─────────────┼──────────┼────────────────────┤
│ FP4 (fp4_e2m1) │ draft_len=3 │ 94.0%    │ 891.3 tok/s (+37%) │
└────────────────┴─────────────┴──────────┴────────────────────┘

FP4 achieves a ~29-37% throughput improvement over FP8; accuracy matches FP8 exactly in the non-MTP setting and stays within 2 points under MTP.

Requirements

- Blackwell GPU (SM120)
- CUDA 13.0+, PyTorch 2.9.1+
- FlashInfer >= 0.6.3 (built from source)

Changed files

┌────────────────────────────────┬─────────────────────────────────────────────────────────────────────┐
│              File              │                               Change                                │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ kv_cache_quant_method.py       │ New — Strategy pattern ABC + NVFP4/MXFP4 subclasses                 │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ kvfp4_tensor.py                │ FP4 quantize/dequantize kernels, flashinfer wrapper                 │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ trtllm_mha_backend.py          │ XQA decode for FP4, MTP target_verify/draft_extend with causal mask │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ flashinfer_backend.py          │ NVFP4 dequant state init, FP4→FP8 prefill path                      │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ memory_pool.py                 │ Pool integration with quant_method, scale buffer management         │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ model_runner_kv_cache_mixin.py │ NVFP4Method creation and scale loading                              │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ attention_registry.py          │ Split prefill/decode backend validation for Blackwell               │
├────────────────────────────────┼─────────────────────────────────────────────────────────────────────┤
│ docs/nvfp4_kv_cache.md         │ New — Documentation with usage, benchmarks, architecture            │
└────────────────────────────────┴─────────────────────────────────────────────────────────────────────┘
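To make the two-level scaling concrete, here is a NumPy mock of the quantize/dequantize round trip that `kvfp4_tensor.py` wraps. It keeps the block scale as a float instead of FP8 E4M3 and assumes the standard 16-element NVFP4 block; it is a numerical sketch, not the flashinfer kernel:

```python
import numpy as np

# Positive magnitudes representable in FP4 E2M1: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_block(x, global_scale):
    # Per-block scale chosen so the block max maps to the FP4 max magnitude (6.0).
    # Real NVFP4 stores this scale in FP8 E4M3; kept as float here for clarity.
    block_scale = np.abs(x).max() / (6.0 * global_scale) + 1e-12
    scaled = x / (block_scale * global_scale)
    # Snap each magnitude to the nearest representable E2M1 value, keep the sign.
    idx = np.abs(np.abs(scaled)[:, None] - E2M1_VALUES).argmin(axis=1)
    return np.sign(scaled) * E2M1_VALUES[idx], block_scale

def dequantize_block(q, block_scale, global_scale):
    return q * block_scale * global_scale

rng = np.random.default_rng(0)
x = rng.normal(size=16).astype(np.float32)   # one 16-element NVFP4 block
q, s = quantize_block(x, global_scale=1.0)
x_hat = dequantize_block(q, s, global_scale=1.0)
rel_err = np.abs(x - x_hat).max() / np.abs(x).max()  # bounded by ~1/6 for this scheme
```

Since the largest gap between adjacent E2M1 magnitudes is 2 (between 4 and 6), the worst-case error per value is half that gap times the scale, i.e. about 1/6 of the block maximum.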

Test plan

- GSM8K accuracy check (1319 questions, FP4 vs FP8 baseline)
- MTP + FP4 end-to-end serving test
- Verify no regression on FP8 KV cache path
- Verify BF16 KV cache path unaffected (NoneMethod)



## Checklist

- [ ] Format your code according to the [Format code with pre-commit](https://docs.sglang.io/developer_guide/contribution_guide.html#format-code-with-pre-commit).
- [ ] Add unit tests according to the [Run and add unit tests](https://docs.sglang.io/developer_guide/contribution_guide.html#run-and-add-unit-tests).
- [ ] Update documentation according to [Write documentations](https://docs.sglang.io/developer_guide/contribution_guide.html#write-documentations).
- [ ] Provide accuracy and speed benchmark results according to [Test the accuracy](https://docs.sglang.io/developer_guide/contribution_guide.html#test-the-accuracy) and [Benchmark the speed](https://docs.sglang.io/developer_guide/contribution_guide.html#benchmark-the-speed).
- [ ] Follow the SGLang code style [guidance](https://docs.sglang.io/developer_guide/contribution_guide.html#code-style-guidance).

## Review and Merge Process

1. Ping Merge Oncalls to start the process. See the [PR Merge Process](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md#pull-request-merge-process).
2. Get approvals from [CODEOWNERS](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and other reviewers.
3. Trigger CI tests with [comments](https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests) or contact authorized users to do so.
 - Common commands include `/tag-and-rerun-ci`, `/tag-run-ci-label`, `/rerun-failed-ci`
4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements NVFP4 KV cache support for Blackwell GPUs (SM100/SM120), featuring a two-level scaling scheme to reduce memory overhead. The implementation includes a new quantization strategy pattern, FlashInfer and TRT-LLM XQA kernel integration, and memory pool updates. Review feedback identifies a critical regression in backend dispatching, potential silent failures in CUDA kernels on incompatible hardware, and several code quality improvements, including the removal of magic numbers and outdated TODO comments.

Comment on lines +438 to +439
```python
return triton_w8a8_block_fp8_linear
```

Contributor


critical

The function _dispatch_auto_backend now unconditionally returns triton_w8a8_block_fp8_linear, making the subsequent backend selection logic for DeepGEMM, FlashInfer, etc., unreachable. This appears to be a temporary change for debugging and will break the intended automatic backend dispatching. This should be removed.

Comment on lines +562 to +568
```cpp
#if HAS_FP8_SUPPORT
const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
#else
const float scale_0 = 1.0f;
const float scale_1 = 1.0f;
#endif
```
Contributor


high

The CUDA kernel nvfp4_dequant_vectorized_kernel has a fallback for when HAS_FP8_SUPPORT is false, which sets the scales to 1.0f. This will produce incorrect dequantization results silently on hardware that doesn't support FP8, as it ignores the block scales. The kernel should instead fail with an error or there should be a compile-time assertion if FP8 support is required but not available.

Suggested change:

```cpp
#if !HAS_FP8_SUPPORT
#error "This kernel requires FP8 support, which is not available on this architecture."
#endif
const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
```

Comment on lines +771 to +773
```cpp
#else
return 0;
#endif
```
Contributor


high

The fp32_vec_to_e2m1 device function returns 0 if __CUDA_ARCH__ < 1000. This will lead to incorrect quantization (all zeros) on older architectures without any warning. Since this function is critical for the SM100+ quantization kernel, it should assert that the architecture is supported.

Suggested change:

```cpp
#else
static_assert(__CUDA_ARCH__ >= 1000, "This function requires SM100 or newer architecture.");
return 0;
#endif
```

NVFP4 KV cache support is on the `nvfp4-kvcache-sm120-v2` branch. Clone from the fork and install in editable mode:

```bash
git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git
```
Contributor


medium

The git clone command points to a personal fork (github.com/samuellees/sglang.git). For official documentation, this should be updated to point to the main project repository's branch before this pull request is merged.

Suggested change:

```bash
git clone -b nvfp4-kvcache-sm120-v2 https://github.com/sgl-project/sglang.git
```

```python
if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padde_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
    import numpy as np
```
Contributor


medium

The import numpy as np statement is inside the _prepare_nvfp4_metadata_for_extend_base method. According to PEP 8, imports should be at the top of the file. This also avoids repeated import overhead if the method is called multiple times.


```python
# logger.debug(f"[KERNEL DEBUG] ====== End of Parameter Dump ======")

# TODO(Sam): NVFP4 kv cache is not supported or MTP. Because draft extend will invoke this api, it needs nvfp4 kv cache support.
```
Contributor


medium

This TODO comment appears to be outdated or confusing. It states that NVFP4 KV cache is not supported for MTP, but this pull request seems to add this support by routing draft_extend and target_verify through the XQA decode kernel. Please remove or update this comment to reflect the current implementation.

Comment on lines +213 to +215
```python
if self.sm_version == 100:
    k_scale *= 6.0
    v_scale *= 6.0
```
Contributor


medium

The code applies a hardcoded scaling factor of 6.0 for SM100 GPUs. While the comment explains this is for hardware alignment, such "magic numbers" are hard to maintain. It would be better to define this as a named constant with a more detailed explanation, and ideally a reference to NVIDIA's documentation if available.

Suggested change:

```python
# SM100 requires a 6x adjustment to align FP4 range with hardware expectations.
# See [link to NVIDIA doc or further explanation if possible].
SM100_FP4_SCALE_ADJUSTMENT = 6.0
if self.sm_version == 100:
    k_scale *= SM100_FP4_SCALE_ADJUSTMENT
    v_scale *= SM100_FP4_SCALE_ADJUSTMENT
```

@samuellees samuellees mentioned this pull request Mar 28, 2026
- Enable NVFP4 KV Cache for SM100 (B200) and SM120 (RTX PRO 6000)
- Extract KVCacheQuantMethod ABC with NoneMethod/NVFP4Method/MXFP4Method subclasses
- Migrate quantize kernel to flashinfer fp4_quantize, keep custom CUDA dequant
- Add MTP (speculative decoding) support for NVFP4/FP8 KV cache
- XQA decode kernel integration with proper scale factor handling
- Centralize FP4 buffer creation, quantize/dequant, and scale management
@samuellees force-pushed the nvfp4-kvcache-sm120 branch from a5dd741 to 4473ed7 on March 28, 2026 at 13:14
Contributor Author

@samuellees samuellees left a comment


Review for the first round


Labels

blackwell (SM100/SM120), documentation (Improvements or additions to documentation), quant (LLM Quantization)
