# [Feature] Add FP4 KV cache support for SM120 GPUs #21601

samuellees wants to merge 1 commit into sgl-project:main
Conversation
Code Review
This pull request implements NVFP4 KV cache support for Blackwell GPUs (SM100/SM120), featuring a two-level scaling scheme to reduce memory overhead. The implementation includes a new quantization strategy pattern, FlashInfer and TRT-LLM XQA kernel integration, and memory pool updates. Review feedback identifies a critical regression in backend dispatching, potential silent failures in CUDA kernels on incompatible hardware, and several code quality improvements, including the removal of magic numbers and outdated TODO comments.
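The two-level scaling scheme can be illustrated with a minimal sketch: each 16-element FP4 (E2M1) block carries a per-block FP8 scale, and a global FP32 scale covers the whole tensor. The code below is a hypothetical illustration of that idea, not the PR's actual kernel; the lookup table is the standard E2M1 value set.

```python
# Illustrative two-level NVFP4 dequantization (hypothetical, not the PR's CUDA kernel).
# Level 1: per-16-element block scale (stored as FP8 in the real kernel);
# level 2: a single global FP32 scale.

# FP4 E2M1 code -> magnitude lookup (bit 3 is the sign bit)
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code into a float."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_VALUES[code & 0x7]

def dequantize_block(codes, block_scale: float, global_scale: float):
    """Dequantize one FP4 block using its two-level scale."""
    return [decode_e2m1(c) * block_scale * global_scale for c in codes]

# Example: code 7 is +6.0 in E2M1; 6.0 * 0.5 * 2.0 == 6.0
print(dequantize_block([7], 0.5, 2.0))  # [6.0]
```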
```python
return triton_w8a8_block_fp8_linear
```
The function _dispatch_auto_backend now unconditionally returns triton_w8a8_block_fp8_linear, making the subsequent backend selection logic for DeepGEMM, FlashInfer, etc., unreachable. This appears to be a temporary change for debugging and will break the intended automatic backend dispatching. This should be removed.
```cpp
#if HAS_FP8_SUPPORT
  const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
  const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
#else
  const float scale_0 = 1.0f;
  const float scale_1 = 1.0f;
#endif
```
The CUDA kernel nvfp4_dequant_vectorized_kernel has a fallback for when HAS_FP8_SUPPORT is false, which sets the scales to 1.0f. This will produce incorrect dequantization results silently on hardware that doesn't support FP8, as it ignores the block scales. The kernel should instead fail with an error or there should be a compile-time assertion if FP8 support is required but not available.
Suggested change:

```diff
-#if HAS_FP8_SUPPORT
-  const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
-  const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
-#else
-  const float scale_0 = 1.0f;
-  const float scale_1 = 1.0f;
-#endif
+#if !HAS_FP8_SUPPORT
+#error "This kernel requires FP8 support, which is not available on this architecture."
+#endif
+  const float scale_0 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_0));
+  const float scale_1 = static_cast<float>(*reinterpret_cast<const __nv_fp8_e4m3*>(&scale_fp8_1));
```
```cpp
#else
  return 0;
#endif
```
The fp32_vec_to_e2m1 device function returns 0 if __CUDA_ARCH__ < 1000. This will lead to incorrect quantization (all zeros) on older architectures without any warning. Since this function is critical for the SM100+ quantization kernel, it should assert that the architecture is supported.
Suggested change:

```diff
-#else
-  return 0;
-#endif
+#else
+  static_assert(__CUDA_ARCH__ >= 1000, "This function requires SM100 or newer architecture.");
+  return 0;
+#endif
```
NVFP4 KV cache support is on the `nvfp4-kvcache-sm120-v2` branch. Clone from the fork and install in editable mode:

```bash
git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git
```
The git clone command points to a personal fork (github.com/samuellees/sglang.git). For official documentation, this should be updated to point to the main project repository's branch before this pull request is merged.
Suggested change:

```diff
-git clone -b nvfp4-kvcache-sm120-v2 https://github.com/samuellees/sglang.git
+git clone -b nvfp4-kvcache-sm120-v2 https://github.com/sgl-project/sglang.git
```
```python
if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padded_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
    import numpy as np
```

```python
# logger.debug(f"[KERNEL DEBUG] ====== End of Parameter Dump ======")
```

```python
# TODO(Sam): NVFP4 kv cache is not supported for MTP. Because draft extend will invoke this api, it needs nvfp4 kv cache support.
```
This TODO comment appears to be outdated or confusing. It states that NVFP4 KV cache is not supported for MTP, but this pull request seems to add this support by routing draft_extend and target_verify through the XQA decode kernel. Please remove or update this comment to reflect the current implementation.
```python
if self.sm_version == 100:
    k_scale *= 6.0
    v_scale *= 6.0
```
The code applies a hardcoded scaling factor of 6.0 for SM100 GPUs. While the comment explains this is for hardware alignment, such "magic numbers" are hard to maintain. It would be better to define this as a named constant with a more detailed explanation, and ideally a reference to NVIDIA's documentation if available.
Suggested change:

```diff
-if self.sm_version == 100:
-    k_scale *= 6.0
-    v_scale *= 6.0
+# SM100 requires a 6x adjustment to align FP4 range with hardware expectations.
+# See [link to NVIDIA doc or further explanation if possible].
+SM100_FP4_SCALE_ADJUSTMENT = 6.0
+if self.sm_version == 100:
+    k_scale *= SM100_FP4_SCALE_ADJUSTMENT
+    v_scale *= SM100_FP4_SCALE_ADJUSTMENT
```
- Enable NVFP4 KV Cache for SM100 (B200) and SM120 (RTX PRO 6000)
- Extract KVCacheQuantMethod ABC with NoneMethod/NVFP4Method/MXFP4Method subclasses
- Migrate quantize kernel to flashinfer fp4_quantize, keep custom CUDA dequant
- Add MTP (speculative decoding) support for NVFP4/FP8 KV cache
- XQA decode kernel integration with proper scale factor handling
- Centralize FP4 buffer creation, quantize/dequant, and scale management
samuellees force-pushed from a5dd741 to 4473ed7.
samuellees left a comment:
Review for the first round
Rebased from #18314
Summary
Add NVFP4 (FP4 E2M1) KV cache quantization support for Blackwell GPUs, reducing KV cache memory by ~2x compared to FP8 with no accuracy loss on GSM8K.
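A quick back-of-envelope check of the memory claim: FP4 halves the per-element storage relative to FP8, and the per-block scales add a small overhead. The sketch below assumes one 1-byte block scale per 16 FP4 elements (an illustrative layout, not necessarily the PR's exact buffer format), giving an effective ratio a bit under 2x.

```python
# Back-of-envelope KV cache size: FP8 (1 B/elem) vs FP4 (0.5 B/elem + block scales).
# Assumed layout (illustrative): one 1-byte FP8 scale per 16 FP4 elements.
head_dim, num_kv_heads = 128, 8
elems_per_token = 2 * num_kv_heads * head_dim  # K and V planes

fp8_bytes = elems_per_token * 1.0
fp4_bytes = elems_per_token * 0.5 + (elems_per_token / 16) * 1.0  # data + scales

ratio = fp8_bytes / fp4_bytes
print(f"FP8: {fp8_bytes:.0f} B/token, FP4: {fp4_bytes:.0f} B/token, ratio: {ratio:.2f}x")
```

Under this layout the ratio works out to about 1.78x, which is where the "~2x" figure comes from once scale overhead is counted.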
Key changes
- `KVCacheQuantMethod` ABC with `NoneMethod`, `NVFP4Method`, `MXFP4Method` subclasses (`kv_cache_quant_method.py`). Adding a new FP4 scheme only requires implementing one subclass and registering it.
- Scale buffers are managed in the memory pool and stored alongside FP4 KV data.
- `fp4_quantize`: replaces the custom JIT CUDA kernel with flashinfer's optimized implementation.
- MTP support: `target_verify`/`draft_extend` route through the XQA decode kernel with causal masking (`--speculative-attention-mode decode`). Verified on models such as Qwen3.5-35B-A3B.
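The strategy pattern described above can be sketched as follows. This is a hedged illustration of the shape, not the PR's actual API: the registry, decorator, and method signatures here are hypothetical, only the class names come from the PR.

```python
# Illustrative sketch of a KV cache quantization strategy pattern.
# Class names follow the PR; the registry and signatures are assumptions.
from abc import ABC, abstractmethod

_QUANT_METHODS = {}

def register_quant_method(dtype_name):
    """Map a --kv-cache-dtype string to a strategy class."""
    def deco(cls):
        _QUANT_METHODS[dtype_name] = cls
        return cls
    return deco

class KVCacheQuantMethod(ABC):
    @abstractmethod
    def quantize(self, kv):
        ...

    @abstractmethod
    def dequantize(self, kv_q):
        ...

@register_quant_method("fp4_e2m1")
class NVFP4Method(KVCacheQuantMethod):
    def quantize(self, kv):
        return ("nvfp4", kv)  # placeholder payload

    def dequantize(self, kv_q):
        return kv_q[1]

def get_quant_method(dtype_name):
    # Adding a new scheme = one subclass + one registration.
    return _QUANT_METHODS[dtype_name]()

print(type(get_quant_method("fp4_e2m1")).__name__)  # NVFP4Method
```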
## Usage

```bash
python3 -m sglang.launch_server \
  --model-path <model_path> \
  --kv-cache-dtype fp4_e2m1 \
  --prefill-attention-backend flashinfer \
  --decode-attention-backend trtllm_mha \
  --disable-radix-cache
```

## Benchmark (Qwen3.5-35B-A3B, GSM8K 100q, Blackwell)

| KV Cache | MTP | Accuracy | Throughput |
| --- | --- | --- | --- |
| FP8 (fp8_e4m3) | No | 91.0% | 350.6 tok/s |
| FP4 (fp4_e2m1) | No | 91.0% | 452.1 tok/s (+29%) |
| FP8 (fp8_e4m3) | draft_len=3 | 96.0% | 652.2 tok/s |
| FP4 (fp4_e2m1) | draft_len=3 | 94.0% | 891.3 tok/s (+37%) |

FP4 delivers a ~29-37% throughput improvement over FP8, with no accuracy loss in the non-MTP setting (91.0% for both) and a small drop with MTP (94.0% vs 96.0%).

## Requirements

- Blackwell GPU (SM120)
- CUDA 13.0+, PyTorch 2.9.1+
- FlashInfer >= 0.6.3 (built from source)

## Changed files

| File | Change |
| --- | --- |
| kv_cache_quant_method.py | New — Strategy pattern ABC + NVFP4/MXFP4 subclasses |
| kvfp4_tensor.py | FP4 quantize/dequantize kernels, flashinfer wrapper |
| trtllm_mha_backend.py | XQA decode for FP4, MTP target_verify/draft_extend with causal mask |
| flashinfer_backend.py | NVFP4 dequant state init, FP4→FP8 prefill path |
| memory_pool.py | Pool integration with quant_method, scale buffer management |
| model_runner_kv_cache_mixin.py | NVFP4Method creation and scale loading |
| attention_registry.py | Split prefill/decode backend validation for Blackwell |
| docs/nvfp4_kv_cache.md | New — Documentation with usage, benchmarks, architecture |

## Test plan

- GSM8K accuracy check (1319 questions, FP4 vs FP8 baseline)
- MTP + FP4 end-to-end serving test
- Verify no regression on FP8 KV cache path
- Verify BF16 KV cache path unaffected (NoneMethod)

## Checklist

- [ ] Format your code according to the [Format code with pre-commit](https://docs.sglang.io/developer_guide/contribution_guide.html#format-code-with-pre-commit).
- [ ] Add unit tests according to the [Run and add unit tests](https://docs.sglang.io/developer_guide/contribution_guide.html#run-and-add-unit-tests).
- [ ] Update documentation according to [Write documentations](https://docs.sglang.io/developer_guide/contribution_guide.html#write-documentations).
- [ ] Provide accuracy and speed benchmark results according to [Test the accuracy](https://docs.sglang.io/developer_guide/contribution_guide.html#test-the-accuracy) and [Benchmark the speed](https://docs.sglang.io/developer_guide/contribution_guide.html#benchmark-the-speed).
- [ ] Follow the SGLang code style [guidance](https://docs.sglang.io/developer_guide/contribution_guide.html#code-style-guidance).

## Review and Merge Process

1. Ping Merge Oncalls to start the process. See the [PR Merge Process](https://github.com/sgl-project/sglang/blob/main/.github/MAINTAINER.md#pull-request-merge-process).
2. Get approvals from [CODEOWNERS](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and other reviewers.
3. Trigger CI tests with [comments](https://docs.sglang.io/developer_guide/contribution_guide.html#how-to-trigger-ci-tests) or contact authorized users to do so. Common commands include `/tag-and-rerun-ci`, `/tag-run-ci-label`, `/rerun-failed-ci`.
4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.