Conversation
|Tasks|Version|Filter|n-shot|Metric| |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|3|flexible-extract|8|exact_match|↑|0.7157|±|0.0124|
| | |strict-match|8|exact_match|↑|0.2547|±|0.0120|
|Tasks|Version|Filter|n-shot|Metric| |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|3|flexible-extract|8|exact_match|↑|0.7657|±|0.0117|
| | |strict-match|8|exact_match|↑|0.3654|±|0.0133|
Add wrap for sm100 rescale; Clean code

|Tasks|Version|Filter|n-shot|Metric| |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|3|flexible-extract|8|exact_match|↑|0.7513|±|0.0119|
| | |strict-match|8|exact_match|↑|0.3381|±|0.0130|
Summary of Changes

This pull request significantly enhances the SGLang framework by introducing and optimizing support for NVIDIA's NVFP4 key-value (KV) cache, particularly targeting the SM120 (Blackwell) and SM100 architectures. The changes implement specialized CUDA kernels for efficient quantization and dequantization, integrate them into existing attention mechanisms such as FlashInfer and TRT-LLM MHA, and update the memory management system to accommodate the new KV cache format. This work aims to improve performance and memory efficiency for models running on compatible NVIDIA GPUs.
Code Review
This pull request introduces support for NVFP4 KV cache, primarily targeting SM120 architecture. The changes are extensive, adding new quantization/dequantization utilities, a dedicated memory pool class for NVFP4, and updating both the FlashInfer and TRT-LLM MHA attention backends to utilize this new format. While the implementation appears to be on the right track for enabling new hardware features, my review has identified several areas for improvement. There are significant code duplications between the two attention backends that should be refactored into shared utilities for better maintainability. I also found a critical issue that seems to disable optimized kernels, likely a temporary change that needs to be reverted, and a potential bug involving in-place list modification. Addressing these points will improve the robustness and quality of the code.
```python
# 4. AITER (if AMD GPU with AITER enabled)
# 5. Triton (fallback)

return triton_w8a8_block_fp8_linear
```
The function _dispatch_auto_backend has been modified to unconditionally return triton_w8a8_block_fp8_linear. This bypasses the logic to select more optimized backends like DeepGEMM, FlashInfer, or CUTLASS, which will likely lead to a significant performance regression. This looks like a temporary change for debugging and should be removed before merging.
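The intended behavior is a fall-through dispatcher that picks the most optimized backend available and only lands on Triton as a last resort. A minimal sketch of that pattern is below; the function and backend names are illustrative placeholders, not SGLang's actual dispatch code:

```python
# Hypothetical sketch of the expected fall-through dispatch; the availability
# flags and backend names are assumptions for illustration, not SGLang's API.
def dispatch_auto_backend(
    deepgemm_ok: bool = False,
    flashinfer_ok: bool = False,
    cutlass_ok: bool = False,
) -> str:
    """Return the most optimized available backend, falling back to Triton."""
    if deepgemm_ok:
        return "deepgemm_w8a8_block_fp8_linear"
    if flashinfer_ok:
        return "flashinfer_w8a8_block_fp8_linear"
    if cutlass_ok:
        return "cutlass_w8a8_block_fp8_linear"
    # Triton is the fallback; it should never be returned unconditionally.
    return "triton_w8a8_block_fp8_linear"
```

An unconditional `return triton_w8a8_block_fp8_linear` at the top of such a function short-circuits every branch above it, which is why it reads as leftover debugging code.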
```python
paged_seq_lens_cpu = forward_batch.seq_lens_cpu
if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens
    # -> token_indices[page_size, ..., padded_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
```
The code paged_seq_lens_cpu.append(256) modifies the paged_seq_lens_cpu list in-place. This list is a reference to a list within the forward_batch object, which could lead to unexpected side effects in other parts of the codebase that might not expect this modification. It is safer to create a new list for this operation to avoid modifying the original data structure.
```diff
-paged_seq_lens_cpu.append(256)
+paged_seq_lens_cpu_with_dummy = paged_seq_lens_cpu + [256]
```
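The hazard the suggestion avoids is plain Python aliasing: the list obtained from `forward_batch` is a reference, so an in-place `append` is visible to every other holder of that list. A minimal standalone illustration (the variable names stand in for the batch object's fields):

```python
# Minimal illustration of the aliasing hazard behind the review comment.
# `forward_batch_seq_lens` stands in for forward_batch.seq_lens_cpu.
forward_batch_seq_lens = [128, 64]
paged_seq_lens_cpu = forward_batch_seq_lens  # a reference, not a copy

paged_seq_lens_cpu.append(256)  # in-place: the batch's own list now has 3 entries
assert forward_batch_seq_lens == [128, 64, 256]

# The safe pattern builds a new list, leaving the original untouched:
forward_batch_seq_lens = [128, 64]
with_dummy = forward_batch_seq_lens + [256]
assert forward_batch_seq_lens == [128, 64]
assert with_dummy == [128, 64, 256]
```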
```python
def preload_kv_scales(self, config, model_runner: ModelRunner):
    if not self.is_nvfp4_kvcache:
        return None, None
    num_layers = config.num_hidden_layers
    k_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu")
    v_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu")

    from sglang.srt.model_executor.model_runner import resolve_language_model

    attention_layers = []
    language_model = resolve_language_model(model_runner.model)
    for layer in language_model.layers:
        if hasattr(layer, "self_attn"):
            if hasattr(layer.self_attn, "attn"):
                attention_layers.append(layer.self_attn.attn)
            elif hasattr(layer.self_attn, "attn_mqa"):
                attention_layers.append(layer.self_attn.attn_mqa)
        elif hasattr(layer, "attn"):
            attention_layers.append(layer.attn)
        elif hasattr(layer, "attention"):
            if hasattr(layer.attention, "attn"):
                attention_layers.append(layer.attention.attn)

    # logger.info(f"Preloading k/v scales for {len(attention_layers)} layers to GPU")
    for layer in attention_layers:
        layer_id = layer.layer_id
        if layer_id >= len(v_scales_cpu):
            continue

        # prepare k/v global scale
        if not hasattr(layer, "k_scale") or layer.k_scale is None:
            k_scale = 1.0
        else:
            k_scale = layer.k_scale
        if not hasattr(layer, "v_scale") or layer.v_scale is None:
            v_scale = 1.0
        else:
            v_scale = layer.v_scale

        if self.is_sm100_gpu:
            k_scale = k_scale * 6.0
            v_scale = v_scale * 6.0

        k_scales_cpu[layer_id] = k_scale
        v_scales_cpu[layer_id] = v_scale

    # Copy all scales to GPU in one shot
    k_scales_gpu = torch.ones(
        num_layers, dtype=torch.float32, device=model_runner.device
    )
    v_scales_gpu = torch.ones(
        num_layers, dtype=torch.float32, device=model_runner.device
    )
    k_scales_gpu.copy_(k_scales_cpu, non_blocking=True)
    v_scales_gpu.copy_(v_scales_cpu, non_blocking=True)
    # logger.info(f"{k_scales_gpu=}, {v_scales_gpu=}")
    return k_scales_gpu, v_scales_gpu
```
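The `* 6.0` rescale on SM100 is plausibly tied to the E2M1 (FP4) element format, whose largest representable magnitude is 6.0; NVFP4 recipes commonly fold the FP4 and FP8 maxima into the global scale. A plain-Python sketch of that common recipe, offered as an assumption rather than SGLang's exact computation (the function name is hypothetical):

```python
# Assumed context for the 6.0 factor above: E2M1 (FP4) maxes out at 6.0 and
# float8_e4m3fn at 448.0, and the usual NVFP4 global scale combines both.
# This is a sketch of the common recipe, not SGLang's actual code.
FP4_E2M1_MAX = 6.0
FP8_E4M3_MAX = 448.0

def nvfp4_global_scale(values):
    """Derive a global scale from the tensor's absolute maximum."""
    amax = max(abs(v) for v in values)
    return (FP4_E2M1_MAX * FP8_E4M3_MAX) / amax
```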
```python
if (
    self.is_nvfp4_kvcache
    and forward_batch.forward_mode.is_extend_without_speculative()
):
    # path:
    # 1. nvfp4, target model prefill/chunked prefill, w/o cudagraph, is_extend_without_speculative(), context mha kernel
    # 2. nvfp4, draft model prefill/chunked prefill, w/o cudagraph, is_extend_without_speculative(), context mha kernel
    from sglang.srt.layers.quantization.fp4_utils import NVFP4QuantizeUtil

    batch_size = forward_batch.batch_size

    k_buffer_nvfp4, k_scales_buffer = (
        forward_batch.token_to_kv_pool.get_fp4_key_buffer(layer.layer_id)
    )
    v_buffer_nvfp4, v_scales_buffer = (
        forward_batch.token_to_kv_pool.get_fp4_value_buffer(layer.layer_id)
    )
    k_buffer_dq, v_buffer_dq = forward_batch.token_to_kv_pool.get_dq_kv_buffer()

    # Convert current k/v to fp8 once
    k_cur_fp8 = k.to(torch.float8_e4m3fn)
    v_cur_fp8 = v.to(torch.float8_e4m3fn)

    # Process each request in the batch
    cur_batch_start_loc_cpu = 0
    # skip first page for dummy output
    cur_token_idx_dq_buffer_cpu = self.page_size
    for batch_idx in range(batch_size):
        req_pool_idx = self.cpu_req_pool_indices[batch_idx]
        prev_len = forward_batch.extend_prefix_lens_cpu[batch_idx]
        extend_len = forward_batch.extend_seq_lens_cpu[batch_idx]
        # prev_len = self.prefix_lengths_kv_cpu[batch_idx]
        # extend_len = self.extend_lengths_kv_cpu[batch_idx]

        # Dequantize and copy previous KV
        if prev_len > 0:
            prev_token_indices = forward_batch.req_to_token_pool.req_to_token[
                req_pool_idx, :prev_len
            ]
            k_prev_nvfp4 = k_buffer_nvfp4[prev_token_indices]
            k_prev_scales = k_scales_buffer[prev_token_indices]
            v_prev_nvfp4 = v_buffer_nvfp4[prev_token_indices]
            v_prev_scales = v_scales_buffer[prev_token_indices]

            # Dequantize: [prev_len, num_heads, head_dim]
            k_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize(
                k_prev_nvfp4.view(torch.uint8),
                k_prev_scales,
                cur_k_scale_gpu,
            )
            v_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize(
                v_prev_nvfp4.view(torch.uint8),
                v_prev_scales,
                cur_v_scale_gpu,
            )
            k_prev_fp8 = k_prev_bf16.to(torch.float8_e4m3fn)
            v_prev_fp8 = v_prev_bf16.to(torch.float8_e4m3fn)

            # Direct contiguous copy
            k_buffer_dq[
                cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu
                + prev_len
            ] = k_prev_fp8
            v_buffer_dq[
                cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu
                + prev_len
            ] = v_prev_fp8

        # Write the current chunk
        cur_end = cur_batch_start_loc_cpu + extend_len
        k_cur_chunk = k_cur_fp8[cur_batch_start_loc_cpu:cur_end]
        v_cur_chunk = v_cur_fp8[cur_batch_start_loc_cpu:cur_end]
        k_buffer_dq[
            cur_token_idx_dq_buffer_cpu
            + prev_len : cur_token_idx_dq_buffer_cpu
            + prev_len
            + extend_len
        ] = k_cur_chunk
        v_buffer_dq[
            cur_token_idx_dq_buffer_cpu
            + prev_len : cur_token_idx_dq_buffer_cpu
            + prev_len
            + extend_len
        ] = v_cur_chunk

        cur_batch_start_loc_cpu = cur_end
        # align to page size
        cur_token_idx_dq_buffer_cpu = (
            (
                cur_token_idx_dq_buffer_cpu
                + prev_len
                + extend_len
                + self.page_size
                - 1
            )
            // self.page_size
            * self.page_size
        )
```
The NVFP4 dequantization logic within this forward_extend method appears to be a direct copy of the _dequant_nvfp4_kv_for_extend_base function from flashinfer_backend.py. This significant code duplication should be refactored. Consider moving this logic to a shared utility or a common base class to avoid redundancy and make future maintenance easier.
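One concrete candidate for extraction is the ceil-to-page arithmetic at the end of the loop, which both backends would repeat verbatim. A sketch of such a shared helper, with a hypothetical name and placement (it does not exist in SGLang today):

```python
# Hypothetical shared utility both backends could call instead of inlining
# the page-alignment arithmetic; the name and module are suggestions only.
def align_up_to_page(num_tokens: int, page_size: int) -> int:
    """Round num_tokens up to the next multiple of page_size (ceil division)."""
    return (num_tokens + page_size - 1) // page_size * page_size
```

With this in place, the loop's trailing block reduces to a single call such as `cur_token_idx_dq_buffer_cpu = align_up_to_page(cur_token_idx_dq_buffer_cpu + prev_len + extend_len, self.page_size)`, and any future fix to the rounding applies to both backends at once.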
```python
if sum(paged_seq_lens_cpu) > 0:
    # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens
    # -> token_indices[page_size, ..., padded_prefix_len + 256 + page_size]
    paged_seq_lens_cpu.append(256)
    import numpy as np
```
```diff
 if save_kv_cache:
-    forward_batch.token_to_kv_pool.set_kv_buffer(
-        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-    )
+    if self.is_nvfp4_kvcache:
+        cur_k_scale_gpu = self.k_scales_gpu[
+            layer.layer_id : layer.layer_id + 1
+        ]
+        cur_v_scale_gpu = self.v_scales_gpu[
+            layer.layer_id : layer.layer_id + 1
+        ]
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            cache_loc,
+            k,
+            v,
+            cur_k_scale_gpu,
+            cur_v_scale_gpu,
+        )
+    else:
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+        )
```
Rebased and moved to #21601
Motivation
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci