
NVFP4 KV Cache for SM120#18314

Closed
samuellees wants to merge 11 commits into sgl-project:main from samuellees:nvfp4-kvcache-sm120

Conversation

@samuellees
Contributor

@samuellees samuellees commented Feb 5, 2026

Rebased and moved to #21601

Motivation

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist


|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.7157|±  |0.0124|
|     |       |strict-match    |     8|exact_match|↑  |0.2547|±  |0.0120|

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.7657|±  |0.0117|
|     |       |strict-match    |     8|exact_match|↑  |0.3654|±  |0.0133|

Add wrap for sm100 rescale;
Clean code

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     8|exact_match|↑  |0.7513|±  |0.0119|
|     |       |strict-match    |     8|exact_match|↑  |0.3381|±  |0.0130|
@github-actions github-actions bot added the blackwell SM100/SM120 label Feb 5, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @samuellees, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the SGLang framework by introducing and optimizing support for NVIDIA's NVFP4 Key-Value (KV) cache, particularly targeting the SM120 (Blackwell) and SM100 architectures. The changes involve implementing specialized CUDA kernels for efficient quantization and dequantization, integrating these into existing attention mechanisms like FlashInfer and TRT-LLM MHA, and updating the memory management system to accommodate the new KV cache format. This work aims to improve performance and memory efficiency for models running on compatible NVIDIA GPUs.

Highlights

  • NVFP4 KV Cache Support: Introduced comprehensive support for NVIDIA's NVFP4 Key-Value (KV) cache, including specialized data structures and logic for handling this quantized format across different attention backends.
  • Optimized Quantization/Dequantization Kernels: Implemented highly optimized CUDA kernels for NVFP4 quantization and dequantization (cuda_nvfp4_dequantize, cuda_nvfp4_quantize_blackwell) within kvfp4_tensor.py, specifically tailored for SM100 (Blackwell) and SM120 architectures to enhance performance.
  • FlashInfer Backend Integration: Modified the flashinfer_backend.py to seamlessly integrate NVFP4 KV cache during both prefill and extend attention stages, including logic for preloading K/V scales and preparing dequantization metadata.
  • TRT-LLM MHA Backend Integration: Extended the trtllm_mha_backend.py to support NVFP4 KV cache for decode and extend operations, incorporating specific handling for SM100/SM120 GPUs and adjusting scaling factors as required by the TRT-LLM kernel.
  • Memory Pool Enhancements: Added a new MHATokenToKVPoolNVFP4 class in memory_pool.py to manage NVFP4 KV cache buffers and their associated FP8 scales, and updated the memory pool selection logic to utilize this new class when NVFP4 is enabled.
  • Refined Backend Compatibility Checks: Improved the server_args.py to include more granular compatibility checks for the trtllm_mha attention backend, differentiating support for prefill (SM100 only) and decode (SM90, SM100, SM120) operations.


Changelog
  • python/sglang/srt/layers/attention/flashinfer_backend.py
    • Imported ModelConfig and added _preload_kv_scales for NVFP4 KV cache.
    • Introduced is_nvfp4_kvcache and kv_cache_dtype_alias properties.
    • Initialized k_scales_gpu and v_scales_gpu for NVFP4 KV Cache.
    • Added _prepare_nvfp4_metadata_for_extend_base to manage dequantization page tables.
    • Modified init_forward_metadata to incorporate NVFP4 metadata preparation.
    • Implemented _dequant_nvfp4_kv_for_extend_base for dequantizing KV cache during extend.
    • Updated forward_extend to handle NVFP4 KV cache, including assertions for cross-attention and using dequantized buffers.
    • Adjusted data_type in PrefillWrapperPaged and DecodeWrapperPaged to use the new alias.
    • Modified update_single_wrapper and call_begin_forward to accept custom_kv_indices.
  • python/sglang/srt/layers/attention/trtllm_mha_backend.py
    • Imported SM architecture support utilities (is_sm90_supported, is_sm100_supported, is_sm120_supported).
    • Added prefix_lengths_kv_cpu and extend_lengths_kv_cpu to TRTLLMMHAMetadata.
    • Initialized NVFP4 related flags (is_xqa_impl, is_sm100_gpu, is_nvfp4_kvcache).
    • Added preload_kv_scales method for NVFP4 KV cache scales.
    • Updated init_forward_metadata to set NVFP4 specific length attributes.
    • Modified forward_decode to support NVFP4 KV cache, including scale adjustments and KV cache preparation.
    • Modified forward_extend for NVFP4 KV cache, incorporating dequantization logic and specific handling for different forward modes.
  • python/sglang/srt/layers/quantization/fp8_utils.py
    • Forced _dispatch_auto_backend to return triton_w8a8_block_fp8_linear.
  • python/sglang/srt/layers/quantization/kvfp4_tensor.py
    • Introduced FP4KVCacheRecipe enum for different FP4 formats.
    • Expanded E2M1_VALUES to include negative values for E2M1 format.
    • Added NVFP4QuantizeUtil class with fi_nvfp4_quantize, cuda_nvfp4_dequantize, cuda_nvfp4_quantize_blackwell, batched_quantize, and batched_dequantize methods.
    • Included inline CUDA C++ kernels for optimized NVFP4 dequantization and quantization for SM100/SM120 architectures.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Imported NVFP4QuantizeUtil and SM architecture support functions.
    • Introduced MHATokenToKVPoolNVFP4 class for managing NVFP4 KV cache buffers and scales.
    • Overrode _create_buffers, _clear_buffers, _get_key_buffer, _get_value_buffer, and set_kv_buffer in MHATokenToKVPoolNVFP4 for NVFP4 specific logic.
    • Added get_fp4_value_buffer, get_fp4_key_buffer, and get_dq_kv_buffer methods.
    • Updated MHATokenToKVPoolFP4 to use NVFP4QuantizeUtil.
    • Modified TokenToKVPool to select MHATokenToKVPoolNVFP4 when kv_cache_dtype is torch.float4_e2m1fn_x2.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Updated import from MHATokenToKVPoolFP4 to MHATokenToKVPoolNVFP4.
    • Modified init_memory_pool to instantiate MHATokenToKVPoolNVFP4 for torch.float4_e2m1fn_x2 KV cache dtype.
  • python/sglang/srt/server_args.py
    • Refined _handle_attention_backend_compatibility for trtllm_mha backend, adding separate checks for prefill (SM100 only) and decode (SM90, SM100, SM120) support.
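On the E2M1 point in the changelog above: E2M1 is the 4-bit float layout NVFP4 packs two-per-byte (1 sign bit, 2 exponent bits, 1 mantissa bit). A minimal decoder sketch (illustrative only, not the PR's kernel code) shows where the expanded negative values come from:

```python
# E2M1 decoding sketch: 1 sign bit, 2 exponent bits, 1 mantissa bit per
# 4-bit code. Exponent 0 is subnormal, so the representable magnitudes are
# {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and the sign bit supplies the negatives.

def e2m1_decode(code: int) -> float:
    sign = -1.0 if (code >> 3) & 1 else 1.0
    exp = (code >> 1) & 0b11
    man = code & 1
    if exp == 0:
        return sign * man * 0.5  # subnormal: 0.0 or +/-0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

E2M1_VALUES = [e2m1_decode(c) for c in range(16)]
```

Codes 0-7 decode to the eight non-negative values; codes 8-15 mirror them with the sign bit set.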

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for NVFP4 KV cache, primarily targeting SM120 architecture. The changes are extensive, adding new quantization/dequantization utilities, a dedicated memory pool class for NVFP4, and updating both the FlashInfer and TRT-LLM MHA attention backends to utilize this new format. While the implementation appears to be on the right track for enabling new hardware features, my review has identified several areas for improvement. There are significant code duplications between the two attention backends that should be refactored into shared utilities for better maintainability. I also found a critical issue that seems to disable optimized kernels, likely a temporary change that needs to be reverted, and a potential bug involving in-place list modification. Addressing these points will improve the robustness and quality of the code.

# 4. AITER (if AMD GPU with AITER enabled)
# 5. Triton (fallback)

return triton_w8a8_block_fp8_linear
Contributor


critical

The function _dispatch_auto_backend has been modified to unconditionally return triton_w8a8_block_fp8_linear. This bypasses the logic to select more optimized backends like DeepGEMM, FlashInfer, or CUTLASS, which will likely lead to a significant performance regression. This looks like a temporary change for debugging and should be removed before merging.

paged_seq_lens_cpu = forward_batch.seq_lens_cpu
if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padded_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
Contributor


high

The code paged_seq_lens_cpu.append(256) modifies the paged_seq_lens_cpu list in-place. This list is a reference to a list within the forward_batch object, which could lead to unexpected side effects in other parts of the codebase that might not expect this modification. It is safer to create a new list for this operation to avoid modifying the original data structure.

Suggested change
- paged_seq_lens_cpu.append(256)
+ paged_seq_lens_cpu_with_dummy = paged_seq_lens_cpu + [256]
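The aliasing hazard behind this suggestion can be reproduced in isolation (toy example; the class and field names only mimic the real forward_batch):

```python
# Toy reproduction of the aliasing bug: appending to a list obtained from
# another object mutates that object's state too, because both names refer
# to the same list object.

class ForwardBatch:
    def __init__(self, seq_lens_cpu):
        self.seq_lens_cpu = seq_lens_cpu

batch = ForwardBatch([128, 64])
paged_seq_lens_cpu = batch.seq_lens_cpu

# In-place append: the dummy 256 entry leaks into the batch object.
paged_seq_lens_cpu.append(256)
leaked = batch.seq_lens_cpu          # now [128, 64, 256]

# Safe variant: build a new list, leaving the original untouched.
batch2 = ForwardBatch([128, 64])
with_dummy = batch2.seq_lens_cpu + [256]
untouched = batch2.seq_lens_cpu      # still [128, 64]
```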

Comment on lines +152 to +211
def preload_kv_scales(self, config, model_runner: ModelRunner):
    if not self.is_nvfp4_kvcache:
        return None, None
    num_layers = config.num_hidden_layers
    k_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu")
    v_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu")

    from sglang.srt.model_executor.model_runner import resolve_language_model

    attention_layers = []
    language_model = resolve_language_model(model_runner.model)
    for layer in language_model.layers:
        if hasattr(layer, "self_attn"):
            if hasattr(layer.self_attn, "attn"):
                attention_layers.append(layer.self_attn.attn)
            elif hasattr(layer.self_attn, "attn_mqa"):
                attention_layers.append(layer.self_attn.attn_mqa)
        elif hasattr(layer, "attn"):
            attention_layers.append(layer.attn)
        elif hasattr(layer, "attention"):
            if hasattr(layer.attention, "attn"):
                attention_layers.append(layer.attention.attn)

    # logger.info(f"Preloading k/v scales for {len(attention_layers)} layers to GPU")
    for layer in attention_layers:
        layer_id = layer.layer_id
        if layer_id >= len(v_scales_cpu):
            continue

        # prepare k/v global scale
        if not hasattr(layer, "k_scale") or layer.k_scale is None:
            k_scale = 1.0
        else:
            k_scale = layer.k_scale
        if not hasattr(layer, "v_scale") or layer.v_scale is None:
            v_scale = 1.0
        else:
            v_scale = layer.v_scale

        if self.is_sm100_gpu:
            k_scale = k_scale * 6.0
            v_scale = v_scale * 6.0

        k_scales_cpu[layer_id] = k_scale
        v_scales_cpu[layer_id] = v_scale

    # Copy to GPU in one shot
    k_scales_gpu = torch.ones(
        num_layers, dtype=torch.float32, device=model_runner.device
    )
    v_scales_gpu = torch.ones(
        num_layers, dtype=torch.float32, device=model_runner.device
    )
    k_scales_gpu.copy_(k_scales_cpu, non_blocking=True)
    v_scales_gpu.copy_(v_scales_cpu, non_blocking=True)
    # logger.info(f"{k_scales_gpu=}, {v_scales_gpu=}")
    # import sys
    # sys.stdout.flush()
    return k_scales_gpu, v_scales_gpu

Contributor


high

This preload_kv_scales method is nearly identical to the _preload_kv_scales function in python/sglang/srt/layers/attention/flashinfer_backend.py. To improve maintainability and reduce redundancy, this logic should be extracted into a shared utility function that both backends can call.
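One way the suggested extraction could look, as a hedged sketch (the helper name, signature, and placement are assumptions, not the PR's actual API) that both flashinfer_backend and trtllm_mha_backend could import:

```python
from types import SimpleNamespace

# Hypothetical shared helper: collect per-layer k/v global scales, defaulting
# to 1.0 and applying the SM100 6.0 rescale when requested. Plain lists stand
# in here for the torch tensors used by the real backends.

def collect_kv_scales(attention_layers, num_layers, sm100_rescale=False):
    k_scales = [1.0] * num_layers
    v_scales = [1.0] * num_layers
    mult = 6.0 if sm100_rescale else 1.0
    for layer in attention_layers:
        if layer.layer_id >= num_layers:
            continue
        k_scales[layer.layer_id] = (getattr(layer, "k_scale", None) or 1.0) * mult
        v_scales[layer.layer_id] = (getattr(layer, "v_scale", None) or 1.0) * mult
    return k_scales, v_scales

# Illustrative layers: one with explicit scales, one falling back to 1.0.
demo_layers = [
    SimpleNamespace(layer_id=0, k_scale=2.0, v_scale=0.5),
    SimpleNamespace(layer_id=1, k_scale=None, v_scale=None),
]
k_demo, v_demo = collect_kv_scales(demo_layers, 2, sm100_rescale=True)
```

The layer-discovery walk (self_attn/attn/attention attribute probing) would move into the same utility, so each backend only passes its model and device.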

Comment on lines +882 to +980
if (
    self.is_nvfp4_kvcache
    and forward_batch.forward_mode.is_extend_without_speculative()
):
    # path:
    # 1. nvfp4, target model prefill/chunkedprefill, w/o cudagraph, is_extend_without_speculative(), context mha kernel
    # 2. nvfp4, draft model prefill/chunkedprefill, w/o cudagraph, is_extend_without_speculative(), context mha kernel
    from sglang.srt.layers.quantization.fp4_utils import NVFP4QuantizeUtil

    batch_size = forward_batch.batch_size

    k_buffer_nvfp4, k_scales_buffer = (
        forward_batch.token_to_kv_pool.get_fp4_key_buffer(layer.layer_id)
    )
    v_buffer_nvfp4, v_scales_buffer = (
        forward_batch.token_to_kv_pool.get_fp4_value_buffer(layer.layer_id)
    )
    k_buffer_dq, v_buffer_dq = forward_batch.token_to_kv_pool.get_dq_kv_buffer()

    # Convert current k/v to fp8 once
    k_cur_fp8 = k.to(torch.float8_e4m3fn)
    v_cur_fp8 = v.to(torch.float8_e4m3fn)

    # Process each request in batch
    cur_batch_start_loc_cpu = 0
    # skip first page for dummy output
    cur_token_idx_dq_buffer_cpu = self.page_size
    for batch_idx in range(batch_size):
        req_pool_idx = self.cpu_req_pool_indices[batch_idx]
        prev_len = forward_batch.extend_prefix_lens_cpu[batch_idx]
        extend_len = forward_batch.extend_seq_lens_cpu[batch_idx]
        # prev_len = self.prefix_lengths_kv_cpu[batch_idx]
        # extend_len = self.extend_lengths_kv_cpu[batch_idx]

        # Dequantize and copy previous KV
        if prev_len > 0:
            prev_token_indices = forward_batch.req_to_token_pool.req_to_token[
                req_pool_idx, :prev_len
            ]
            k_prev_nvfp4 = k_buffer_nvfp4[prev_token_indices]
            k_prev_scales = k_scales_buffer[prev_token_indices]
            v_prev_nvfp4 = v_buffer_nvfp4[prev_token_indices]
            v_prev_scales = v_scales_buffer[prev_token_indices]

            # Dequantize: [prev_len, num_heads, head_dim]
            k_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize(
                k_prev_nvfp4.view(torch.uint8),
                k_prev_scales,
                cur_k_scale_gpu,
            )
            v_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize(
                v_prev_nvfp4.view(torch.uint8),
                v_prev_scales,
                cur_v_scale_gpu,
            )
            k_prev_fp8 = k_prev_bf16.to(torch.float8_e4m3fn)
            v_prev_fp8 = v_prev_bf16.to(torch.float8_e4m3fn)

            # Direct continuous copy
            k_buffer_dq[
                cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu
                + prev_len
            ] = k_prev_fp8
            v_buffer_dq[
                cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu
                + prev_len
            ] = v_prev_fp8

        # Write of current chunk
        cur_end = cur_batch_start_loc_cpu + extend_len
        k_cur_chunk = k_cur_fp8[cur_batch_start_loc_cpu:cur_end]
        v_cur_chunk = v_cur_fp8[cur_batch_start_loc_cpu:cur_end]
        k_buffer_dq[
            cur_token_idx_dq_buffer_cpu
            + prev_len : cur_token_idx_dq_buffer_cpu
            + prev_len
            + extend_len
        ] = k_cur_chunk
        v_buffer_dq[
            cur_token_idx_dq_buffer_cpu
            + prev_len : cur_token_idx_dq_buffer_cpu
            + prev_len
            + extend_len
        ] = v_cur_chunk

        cur_batch_start_loc_cpu = cur_end
        # align to page size
        cur_token_idx_dq_buffer_cpu = (
            (
                cur_token_idx_dq_buffer_cpu
                + prev_len
                + extend_len
                + self.page_size
                - 1
            )
            // self.page_size
            * self.page_size
        )
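The cursor update at the end of that loop is a round-up-to-page-boundary. Isolated (the helper name is mine, not the PR's):

```python
# Round n up to the next multiple of page_size, the same arithmetic as
# (cur + prev_len + extend_len + page_size - 1) // page_size * page_size
# in the dequant loop above.

def align_up(n: int, page_size: int) -> int:
    return (n + page_size - 1) // page_size * page_size

# cursor 16, prev_len 3, extend_len 5 -> 24 tokens written -> next page at 32
aligned = align_up(16 + 3 + 5, 16)
```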

Contributor


high

The NVFP4 dequantization logic within this forward_extend method appears to be a direct copy of the _dequant_nvfp4_kv_for_extend_base function from flashinfer_backend.py. This significant code duplication should be refactored. Consider moving this logic to a shared utility or a common base class to avoid redundancy and make future maintenance easier.

if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padded_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
    import numpy as np
Contributor


medium

The numpy import is located inside a method. It is a best practice to place all imports at the top of the file for better readability, consistency, and to avoid potential overhead from repeated imports.

Comment on lines 1104 to +1123
 if save_kv_cache:
-    forward_batch.token_to_kv_pool.set_kv_buffer(
-        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-    )
+    if self.is_nvfp4_kvcache:
+        cur_k_scale_gpu = self.k_scales_gpu[
+            layer.layer_id : layer.layer_id + 1
+        ]
+        cur_v_scale_gpu = self.v_scales_gpu[
+            layer.layer_id : layer.layer_id + 1
+        ]
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            cache_loc,
+            k,
+            v,
+            cur_k_scale_gpu,
+            cur_v_scale_gpu,
+        )
+    else:
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+        )
Contributor


medium

This block of code for saving the KV cache is a duplicate of the logic found earlier in this method (lines 1000-1020). To improve maintainability and reduce redundancy, this logic should be extracted into a private helper method.
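The suggested extraction could take roughly this shape, as a hedged sketch (the helper name and signature are assumptions, not the PR's actual code):

```python
from types import SimpleNamespace

# Hypothetical helper for the duplicated save-to-KV-cache branch: select the
# per-layer NVFP4 scale slices when the NVFP4 cache is active, otherwise fall
# back to the layer's own scales. Plain lists stand in for GPU tensors.

def resolve_kv_scales(is_nvfp4, layer, k_scales_gpu=None, v_scales_gpu=None):
    if is_nvfp4:
        i = layer.layer_id
        return k_scales_gpu[i : i + 1], v_scales_gpu[i : i + 1]
    return layer.k_scale, layer.v_scale

# Both call sites would then collapse to one set_kv_buffer call:
#   k_s, v_s = resolve_kv_scales(self.is_nvfp4_kvcache, layer,
#                                self.k_scales_gpu, self.v_scales_gpu)
#   forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v, k_s, v_s)

demo_layer = SimpleNamespace(layer_id=1, k_scale=0.25, v_scale=0.5)
k_s, v_s = resolve_kv_scales(True, demo_layer, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
k_plain, v_plain = resolve_kv_scales(False, demo_layer)
```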

@samuellees
Contributor Author

Rebased and moved to #21601
