diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 0b76b3556d9d..37c8c3a96855 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -16,6 +16,7 @@ import torch +from sglang.srt.configs.model_config import ModelConfig from sglang.srt.dllm.config import DllmConfig from sglang.srt.environ import envs from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -109,6 +110,62 @@ class PrefillMetadata: global_override_indptr_cpu = None +def _preload_kv_scales( + config: ModelConfig, model_runner: ModelRunner, is_sm100_supported: bool +): + num_layers = config.num_hidden_layers + k_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu") + v_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu") + + from sglang.srt.model_executor.model_runner import resolve_language_model + + attention_layers = [] + language_model = resolve_language_model(model_runner.model) + for layer in language_model.layers: + if hasattr(layer, "self_attn"): + if hasattr(layer.self_attn, "attn"): + attention_layers.append(layer.self_attn.attn) + elif hasattr(layer.self_attn, "attn_mqa"): + attention_layers.append(layer.self_attn.attn_mqa) + elif hasattr(layer, "attn"): + attention_layers.append(layer.attn) + elif hasattr(layer, "attention"): + if hasattr(layer.attention, "attn"): + attention_layers.append(layer.attention.attn) + + for layer in attention_layers: + layer_id = layer.layer_id + if layer_id >= len(v_scales_cpu): + continue + + # prepare k/v global scale + if not hasattr(layer, "k_scale") or layer.k_scale is None: + k_scale = 1.0 + else: + k_scale = layer.k_scale + if not hasattr(layer, "v_scale") or layer.v_scale is None: + v_scale = 1.0 + else: + v_scale = layer.v_scale + + if is_sm100_supported: + k_scale = k_scale * 6.0 + v_scale = v_scale * 6.0 + + k_scales_cpu[layer_id] = 
k_scale + v_scales_cpu[layer_id] = v_scale + + k_scales_gpu = torch.ones( + num_layers, dtype=torch.float32, device=model_runner.device + ) + v_scales_gpu = torch.ones( + num_layers, dtype=torch.float32, device=model_runner.device + ) + k_scales_gpu.copy_(k_scales_cpu, non_blocking=True) + v_scales_gpu.copy_(v_scales_cpu, non_blocking=True) + return k_scales_gpu, v_scales_gpu + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" @@ -131,9 +188,18 @@ def __init__( self.dllm_config = DllmConfig.from_server_args(model_runner.server_args) self.is_dllm_model = self.dllm_config is not None + self.is_nvfp4_kvcache = model_runner.kv_cache_dtype == torch.float4_e2m1fn_x2 + # For nvfp4 kv cache, we don't use flashinfer-decode, and we use fp8 kv cache for prefill, so set kv_cache_dtype_alias + # to pass the checks + self.kv_cache_dtype_alias = ( + model_runner.kv_cache_dtype + if not self.is_nvfp4_kvcache + else torch.float8_e4m3fn + ) + # Parse constants self.decode_use_tensor_cores = should_use_tensor_core( - kv_cache_dtype=model_runner.kv_cache_dtype, + kv_cache_dtype=self.kv_cache_dtype_alias, num_attention_heads=model_runner.model_config.num_attention_heads // get_attention_tp_size(), num_kv_heads=model_runner.model_config.get_num_kv_heads( @@ -295,6 +361,14 @@ def __init__( self.prefill_cuda_graph_metadata = {} # For verify self.draft_extend_cuda_graph_metadata = {} # For draft extend + # NVFP4 KV Cache + if self.is_nvfp4_kvcache: + self.is_sm100_supported = is_sm100_supported() + self.k_scales_gpu, self.v_scales_gpu = _preload_kv_scales( + model_runner.model_config, model_runner, self.is_sm100_supported + ) + self.page_size = model_runner.page_size + def _process_multi_item_scoring( self, forward_batch: ForwardBatch ) -> MultiItemScoringParams: @@ -417,6 +491,43 @@ def _process_multi_item_scoring( ), ) + def _prepare_nvfp4_metadata_for_extend_base( + self, forward_batch: ForwardBatch, use_ragged: bool = False + ): + """This must be 
called after init_forward_metadata/capture_cuda_graph/replay_cuda_graph""" + # construct nvfp4 kv cache dequant page table for extend stage + self.dq_page_table = None + self.cpu_req_pool_indices = None + if ( + self.is_nvfp4_kvcache + and forward_batch.forward_mode.is_extend_without_speculative() + ): + if use_ragged: + paged_seq_lens_cpu = forward_batch.extend_prefix_lens_cpu + else: + paged_seq_lens_cpu = forward_batch.seq_lens_cpu + if sum(paged_seq_lens_cpu) > 0: + # [prefix_len, 256] -> [padded_prefix_len, 256] -> sum_tokens -> token_indices[page_size, ..., padde_prefix_len + 256 + page_size] + paged_seq_lens_cpu.append(256) + import numpy as np + + paged_seq_lens_cpu = np.array(paged_seq_lens_cpu) + paged_seq_lens_cpu_padded = ( + (paged_seq_lens_cpu + self.page_size - 1) + // self.page_size + * self.page_size + ) + total_paged_tokens = sum(paged_seq_lens_cpu_padded) + self.dq_page_table = torch.arange( + self.page_size, + total_paged_tokens + self.page_size, + device=forward_batch.req_pool_indices.device, + dtype=torch.int32, + ) + self.cpu_req_pool_indices = forward_batch.req_pool_indices.to( + "cpu", non_blocking=True + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( @@ -484,6 +595,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): # Use new backend-specific implementation multi_item_params = self._process_multi_item_scoring(forward_batch) + self._prepare_nvfp4_metadata_for_extend_base(forward_batch, use_ragged) + self.indices_updater_prefill.update( forward_batch.req_pool_indices, forward_batch.seq_lens, @@ -496,6 +609,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): spec_info=None, fixed_split_size=self.prefill_split_tile_size, multi_item_params=multi_item_params, + custom_kv_indices=self.dq_page_table, ) self.forward_metadata = PrefillMetadata( self.prefill_wrappers_paged, @@ -503,6 +617,8 @@ def 
init_forward_metadata(self, forward_batch: ForwardBatch): extend_no_prefix, multi_item_params, ) + # For none-ragged case, we transfer current chunk into dq kv table + self.transfer_cur_chunk_kv = not self.forward_metadata.use_ragged def init_cuda_graph_state( self, @@ -739,6 +855,108 @@ def init_forward_metadata_replay_cuda_graph( def get_cuda_graph_seq_len_fill_value(self): return 1 + def _dequant_nvfp4_kv_for_extend_base( + self, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + transfer_cur_chunk_kv: bool = True, + ): + + cur_k_scale_gpu = self.k_scales_gpu[layer.layer_id : layer.layer_id + 1] + cur_v_scale_gpu = self.v_scales_gpu[layer.layer_id : layer.layer_id + 1] + + from sglang.srt.layers.quantization.kvfp4_tensor import NVFP4QuantizeUtil + + batch_size = forward_batch.batch_size + + k_buffer_nvfp4, k_scales_buffer = ( + forward_batch.token_to_kv_pool.get_fp4_key_buffer(layer.layer_id) + ) + v_buffer_nvfp4, v_scales_buffer = ( + forward_batch.token_to_kv_pool.get_fp4_value_buffer(layer.layer_id) + ) + k_buffer_dq, v_buffer_dq = forward_batch.token_to_kv_pool.get_dq_kv_buffer() + + # Convert current k/v to fp8 once + if transfer_cur_chunk_kv: + k_cur_fp8 = k.to(torch.float8_e4m3fn) + v_cur_fp8 = v.to(torch.float8_e4m3fn) + + # Process each request in batch + cur_batch_start_loc_cpu = 0 + # skip first page for dummy output + cur_token_idx_dq_buffer_cpu = self.page_size + for batch_idx in range(batch_size): + req_pool_idx = self.cpu_req_pool_indices[batch_idx] + prev_len = forward_batch.extend_prefix_lens_cpu[batch_idx] + extend_len = forward_batch.extend_seq_lens_cpu[batch_idx] + + # Dequantize and copy previous KV + if prev_len > 0: + prev_token_indices = forward_batch.req_to_token_pool.req_to_token[ + req_pool_idx, :prev_len + ] + k_prev_nvfp4 = k_buffer_nvfp4[prev_token_indices] + k_prev_scales = k_scales_buffer[prev_token_indices] + v_prev_nvfp4 = v_buffer_nvfp4[prev_token_indices] + v_prev_scales = 
v_scales_buffer[prev_token_indices] + + # Dequantize: [prev_len, num_heads, head_dim] + k_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize( + k_prev_nvfp4.view(torch.uint8), + k_prev_scales, + cur_k_scale_gpu, + ) + v_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize( + v_prev_nvfp4.view(torch.uint8), + v_prev_scales, + cur_v_scale_gpu, + ) + k_prev_fp8 = k_prev_bf16.to(torch.float8_e4m3fn) + v_prev_fp8 = v_prev_bf16.to(torch.float8_e4m3fn) + + # Direct continuous copy + k_buffer_dq[ + cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu + prev_len + ] = k_prev_fp8 + v_buffer_dq[ + cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu + prev_len + ] = v_prev_fp8 + + # Write of current chunk + if transfer_cur_chunk_kv: + cur_end = cur_batch_start_loc_cpu + extend_len + k_cur_chunk = k_cur_fp8[cur_batch_start_loc_cpu:cur_end] + v_cur_chunk = v_cur_fp8[cur_batch_start_loc_cpu:cur_end] + k_buffer_dq[ + cur_token_idx_dq_buffer_cpu + + prev_len : cur_token_idx_dq_buffer_cpu + + prev_len + + extend_len + ] = k_cur_chunk + v_buffer_dq[ + cur_token_idx_dq_buffer_cpu + + prev_len : cur_token_idx_dq_buffer_cpu + + prev_len + + extend_len + ] = v_cur_chunk + cur_batch_start_loc_cpu = cur_end + + # align to page size + cur_token_idx_dq_buffer_cpu = ( + ( + cur_token_idx_dq_buffer_cpu + + prev_len + + extend_len + + self.page_size + - 1 + ) + // self.page_size + * self.page_size + ) + def forward_extend( self, q: torch.Tensor, @@ -760,17 +978,51 @@ def forward_extend( logits_soft_cap = layer.logit_cap q = q.contiguous() + + assert not ( + self.is_nvfp4_kvcache and layer.is_cross_attention + ), "NVFP4 dequant KV cache is not supported for cross-attention" + + # We perform dequant for chunk prefill/cache reuse. 
+ if self.is_nvfp4_kvcache: + self._dequant_nvfp4_kv_for_extend_base( + k, v, layer, forward_batch, self.transfer_cur_chunk_kv + ) + k_buffer_dq, v_buffer_dq = forward_batch.token_to_kv_pool.get_dq_kv_buffer() + k_paged = k_buffer_dq.view(-1, layer.tp_k_head_num, layer.head_dim) + v_paged = v_buffer_dq.view(-1, layer.tp_v_head_num, layer.head_dim) + kv_cache = (k_paged, v_paged) + else: + kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + # use paged attention if not self.forward_metadata.use_ragged: - if k is not None: + if k is not None and save_kv_cache: assert v is not None - if save_kv_cache: + if self.is_nvfp4_kvcache: + cur_k_scale_gpu = self.k_scales_gpu[ + layer.layer_id : layer.layer_id + 1 + ] + cur_v_scale_gpu = self.v_scales_gpu[ + layer.layer_id : layer.layer_id + 1 + ] + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + cache_loc, + k, + v, + cur_k_scale_gpu, + cur_v_scale_gpu, + ) + else: forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) + # We need to process the paged part for nvfp4 kv cache o = prefill_wrapper_paged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + kv_cache, causal=not layer.is_cross_attention, sm_scale=layer.scaling, # Disable sliding window attention for multi-item scoring: @@ -798,6 +1050,9 @@ def forward_extend( # previously cached context without re-materializing KV tensors (e.g., the # IQuestLoopCoder path uses token_to_kv_pool as the KV source). 
if k is None and v is None: + assert ( + not self.is_nvfp4_kvcache + ), "KV cache must be provided for ragged attention when using NVFP4 kv cache" k = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)[0] v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)[1] causal = True @@ -838,7 +1093,7 @@ def forward_extend( ) o2, s2 = prefill_wrapper_paged.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), - forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), + kv_cache, causal=False, sm_scale=layer.scaling, logits_soft_cap=logits_soft_cap, @@ -847,9 +1102,25 @@ def forward_extend( o, _ = merge_state(o1, s1, o2, s2) if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale - ) + if self.is_nvfp4_kvcache: + cur_k_scale_gpu = self.k_scales_gpu[ + layer.layer_id : layer.layer_id + 1 + ] + cur_v_scale_gpu = self.v_scales_gpu[ + layer.layer_id : layer.layer_id + 1 + ] + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + cache_loc, + k, + v, + cur_k_scale_gpu, + cur_v_scale_gpu, + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -913,7 +1184,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBacken get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim - self.data_type = model_runner.kv_cache_dtype + self.data_type = attn_backend.kv_cache_dtype_alias self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size self.attn_backend = attn_backend @@ -1173,7 +1444,7 @@ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBacken get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim - self.data_type = model_runner.kv_cache_dtype + self.data_type = attn_backend.kv_cache_dtype_alias self.q_data_type = model_runner.dtype 
self.sliding_window_size = model_runner.sliding_window_size self.attn_backend = attn_backend @@ -1224,6 +1495,7 @@ def update_single_wrapper( spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, multi_item_params: Optional[MultiItemScoringParams] = None, + custom_kv_indices: Optional[torch.Tensor] = None, ): if use_ragged: # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu @@ -1249,6 +1521,7 @@ def update_single_wrapper( spec_info, fixed_split_size=fixed_split_size, multi_item_params=multi_item_params, + custom_kv_indices=custom_kv_indices, ) def update_sliding_window( @@ -1359,6 +1632,7 @@ def call_begin_forward( use_sliding_window_kv_pool: bool = False, fixed_split_size: Optional[int] = None, multi_item_params: Optional[MultiItemScoringParams] = None, + custom_kv_indices: Optional[torch.Tensor] = None, ): bs = len(seq_lens) if spec_info is None: @@ -1366,20 +1640,24 @@ def call_begin_forward( # Normal extend kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum + 256, - dtype=torch.int32, - device=req_pool_indices.device, - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.req_to_token.shape[1], - ) + + if custom_kv_indices is not None: + kv_indices = custom_kv_indices + else: + kv_indices = torch.empty( + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] custom_mask = None @@ -1406,6 +1684,7 @@ def call_begin_forward( ) if use_sliding_window_kv_pool: + assert not custom_kv_indices, 
"custom_kv_indices incompatible with SWA" kv_last_index = kv_indptr[-1] kv_indices[:kv_last_index] = ( self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa( diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index aa418b7af669..6271fc5f8052 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -21,6 +21,11 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils.common import ( + is_sm90_supported, + is_sm100_supported, + is_sm120_supported, +) logger = logging.getLogger(__name__) @@ -55,6 +60,10 @@ class TRTLLMMHAMetadata: cu_seqlens_k: torch.Tensor = None # Page table, the index of KV Cache Tables/Blocks page_table: torch.Tensor = None + # Prefix length, for NVFP4 KV Cache + prefix_lengths_kv_cpu: torch.Tensor = None + # Extend lengths, for NVFP4 KV Cache + extend_lengths_kv_cpu: torch.Tensor = None class TRTLLMHAAttnBackend(FlashInferAttnBackend): @@ -122,6 +131,84 @@ def __init__( # Forward metadata self.forward_metadata: Optional[TRTLLMMHAMetadata] = None + # Init backend (XQA or TRTLLM-GEN) + # We need to specify q_type and out_type for different backend + # XQA: (q_type must be bf16) + # KV bf16: q_type = bf16, out_type=model_runner.dtype + # KV fp8: q_type = bf16, out_type=model_runner.dtype + # TRTLLM-GEN: + # KV bf16: q_type = bf16, out_type=model_runner.dtype + # KV fp8: q_type = fp8, out_type=model_runner.dtype + self.is_xqa_impl = is_sm90_supported() or is_sm120_supported() + + self.is_sm100_gpu = is_sm100_supported() + self.is_nvfp4_kvcache = self.data_type == torch.float4_e2m1fn_x2 + + # k/v scales on GPU tensor, used for NVFP4 KV Cache + self.k_scales_gpu, self.v_scales_gpu = self.preload_kv_scales( + config, model_runner + ) + + def preload_kv_scales(self, config, 
model_runner: ModelRunner): + if not self.is_nvfp4_kvcache: + return None, None + num_layers = config.num_hidden_layers + k_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu") + v_scales_cpu = torch.ones(num_layers, dtype=torch.float32, device="cpu") + + from sglang.srt.model_executor.model_runner import resolve_language_model + + attention_layers = [] + language_model = resolve_language_model(model_runner.model) + for layer in language_model.layers: + if hasattr(layer, "self_attn"): + if hasattr(layer.self_attn, "attn"): + attention_layers.append(layer.self_attn.attn) + elif hasattr(layer.self_attn, "attn_mqa"): + attention_layers.append(layer.self_attn.attn_mqa) + elif hasattr(layer, "attn"): + attention_layers.append(layer.attn) + elif hasattr(layer, "attention"): + if hasattr(layer.attention, "attn"): + attention_layers.append(layer.attention.attn) + + # logger.info(f"Preloading k/v scales for {len(attention_layers)} layers to GPU") + for layer in attention_layers: + layer_id = layer.layer_id + if layer_id >= len(v_scales_cpu): + continue + + # prepare k/v global scale + if not hasattr(layer, "k_scale") or layer.k_scale is None: + k_scale = 1.0 + else: + k_scale = layer.k_scale + if not hasattr(layer, "v_scale") or layer.v_scale is None: + v_scale = 1.0 + else: + v_scale = layer.v_scale + + if self.is_sm100_gpu: + k_scale = k_scale * 6.0 + v_scale = v_scale * 6.0 + + k_scales_cpu[layer_id] = k_scale + v_scales_cpu[layer_id] = v_scale + + # 一次性拷贝到 GPU + k_scales_gpu = torch.ones( + num_layers, dtype=torch.float32, device=model_runner.device + ) + v_scales_gpu = torch.ones( + num_layers, dtype=torch.float32, device=model_runner.device + ) + k_scales_gpu.copy_(k_scales_cpu, non_blocking=True) + v_scales_gpu.copy_(v_scales_cpu, non_blocking=True) + # logger.info(f"{k_scales_gpu=}, {v_scales_gpu=}") + # import sys + # sys.stdout.flush() + return k_scales_gpu, v_scales_gpu + def init_cuda_graph_state( self, max_bs: int, @@ -318,7 +405,14 @@ def 
init_forward_metadata_capture_cuda_graph( metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :] self.draft_extend_metadata[bs] = metadata + + # if self.is_nvfp4_kvcache: + # self.extend_lengths_kv_cpu = torch.full( + # (bs,), num_tokens_per_bs, dtype=torch.int32, device="cpu") + # self.prefix_lengths_kv_cpu = seq_lens.to("cpu", non_blocking=True) - self.extend_lengths_kv_cpu + self.forward_metadata = metadata + # self._prepare_nvfp4_metadata_for_extend(forward_mode, req_pool_indices) def init_forward_metadata_replay_cuda_graph( self, @@ -421,7 +515,11 @@ def init_forward_metadata_replay_cuda_graph( self.draft_extend_metadata["strided_indices"][:max_seq_pages], ] metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size) + # if self.is_nvfp4_kvcache: + # self.extend_lengths_kv_cpu = accept_length.to("cpu", non_blocking=True) + # self.prefix_lengths_kv_cpu = seq_lens_cpu - self.extend_lengths_kv_cpu self.forward_metadata = metadata + # self._prepare_nvfp4_metadata_for_extend(forward_mode, req_pool_indices) def get_cuda_graph_seq_len_fill_value(self) -> int: """Get the fill value for sequence lengths in CUDA graph.""" @@ -552,6 +650,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.max_seq_len_q = metadata.max_seq_len_k metadata.cu_seqlens_q = metadata.cu_seqlens_k + if self.is_nvfp4_kvcache: + self.extend_lengths_kv_cpu = forward_batch.extend_seq_lens_cpu + self.prefix_lengths_kv_cpu = forward_batch.extend_prefix_lens_cpu + # Convert the page table to a strided format if self.page_size > 1: self.strided_indices = torch.arange( @@ -563,6 +665,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.forward_metadata = metadata + self._prepare_nvfp4_metadata_for_extend(forward_batch) + def forward_decode( self, q: torch.Tensor, @@ -573,57 +677,131 @@ def forward_decode( save_kv_cache: bool = True, **kwargs, ) -> torch.Tensor: - """Run forward for decode using TRTLLM MHA kernel.""" + """ + Run 
forward for decode using TRTLLM MHA kernel. + DECODE + """ cache_loc = forward_batch.out_cache_loc use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k) - if use_fused_fp8_path: - # Use fused FP8 quantization + KV cache write path - self._fused_fp8_set_kv_buffer( - q=q, - k=k, - v=v, - layer=layer, - forward_batch=forward_batch, - ) - k = None - v = None + # prepare k/v global scale + if not hasattr(layer, "k_scale") or layer.k_scale is None: + k_scale = 1.0 else: - # Use original set_kv_buffer path - if save_kv_cache and k is not None: + k_scale = layer.k_scale + if not hasattr(layer, "v_scale") or layer.v_scale is None: + v_scale = 1.0 + else: + v_scale = layer.v_scale + + if self.is_nvfp4_kvcache: + if self.is_sm100_gpu: + # re-scale for requirements of trtllm nvfp4 kv cache kernel + # we only apply this rescale for quant kernel, but not fp8 mha kernel + k_scale = k_scale * 6.0 + v_scale = v_scale * 6.0 + cur_k_scale_gpu = self.k_scales_gpu[layer.layer_id : layer.layer_id + 1] + cur_v_scale_gpu = self.v_scales_gpu[layer.layer_id : layer.layer_id + 1] + + # save k/v cache to kv pool + if save_kv_cache and k is not None: + if use_fused_fp8_path: + # fused fp8 quant + write kv cache + self._fused_fp8_set_kv_buffer( + q=q, + k=k, + v=v, + layer=layer, + forward_batch=forward_batch, + ) + k = None + v = None + elif self.is_nvfp4_kvcache: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + layer, + cache_loc, + k, + v, + cur_k_scale_gpu, + cur_v_scale_gpu, + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, k_scale, v_scale ) - if self.data_type == torch.float8_e4m3fn: + # prepare query + # For XQA, q_dtype should be bf16 + if (self.data_type == torch.float8_e4m3fn or self.is_nvfp4_kvcache) and ( + not self.is_xqa_impl + ): q = q.to(torch.float8_e4m3fn) q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - k_cache, v_cache = 
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - # shape conversion: - # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim] - k_cache = k_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.head_dim - ).permute(0, 2, 1, 3) - v_cache = v_cache.view( - -1, self.page_size, layer.tp_v_head_num, layer.head_dim - ).permute(0, 2, 1, 3) - kv_cache = (k_cache, v_cache) + + # prepare kv cache + if self.is_nvfp4_kvcache: + + k_cache, k_cache_scales = forward_batch.token_to_kv_pool.get_fp4_key_buffer( + layer.layer_id + ) + v_cache, v_cache_scales = ( + forward_batch.token_to_kv_pool.get_fp4_value_buffer(layer.layer_id) + ) + k_cache = k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim // 2 + ).permute(0, 2, 1, 3) + v_cache = v_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim // 2 + ).permute(0, 2, 1, 3) + + k_cache_scales = k_cache_scales.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim // 16 + ).permute(0, 2, 1, 3) + v_cache_scales = v_cache_scales.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim // 16 + ).permute(0, 2, 1, 3) + + # wrap for trtllm-gen mha + # k_scale = k_scale * 6.0 + # v_scale = v_scale * 6.0 + # k_cache_scales = (k_cache_scales.float() / 6.0).to(torch.float8_e4m3fn) + # v_cache_scales = (v_cache_scales.float() / 6.0).to(torch.float8_e4m3fn) + + kv_cache = (k_cache, v_cache) + kv_cache_block_scales = (k_cache_scales, v_cache_scales) + + else: + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + # shape conversion: + # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim] + k_cache = k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ).permute(0, 2, 1, 3) + v_cache = v_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ).permute(0, 2, 1, 3) + kv_cache = (k_cache, v_cache) + kv_cache_block_scales = 
None # TODO: add support for quantization q_scale = 1.0 - k_scale = ( - layer.k_scale_float - if getattr(layer, "k_scale_float", None) is not None - else 1.0 - ) + # k_scale = ( + # layer.k_scale_float + # if getattr(layer, "k_scale_float", None) is not None + # else 1.0 + # ) + bmm1_scale = q_scale * k_scale * layer.scaling - bmm2_scale = 1.0 + bmm2_scale = v_scale # sink: additional value per head in the denominator of the softmax. attention_sink = kwargs.get("sinks", None) - # Call TRT-LLM kernel - # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype + # FIXME(Sam): Here XQA only support fp16 input for now, we will add bf16 support when flashinfer release. + if self.is_xqa_impl and self.is_nvfp4_kvcache: + q = q.to(torch.float16) o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=q, kv_cache=kv_cache, @@ -634,11 +812,12 @@ def forward_decode( bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, window_left=layer.sliding_window_size, - # TODO: add attention_sink operation or nvfp4 scale factor if needed + # TODO: add attention_sink operation if needed sinks=attention_sink, - out_dtype=self.q_data_type, # model_runner.dtype + out_dtype=q.dtype, # model_runner.dtype + kv_cache_sf=kv_cache_block_scales, ) - + o = o.to(self.q_data_type) return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_extend( @@ -651,59 +830,238 @@ def forward_extend( save_kv_cache=True, **kwargs, ): + """ + Target-Prefill[EXTEND](w/o cudagraph), context_api + Draft-Prefill [EXTEND](w/o cudagraph), context_api + Draft-Extend [DRAFT_EXTEND](cudagraph), context_api + Target-Verify [TARGET_VERIFY](cudagraph), decode_api + """ cache_loc = forward_batch.out_cache_loc use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k) - if use_fused_fp8_path: - # Use fused FP8 quantization + KV cache write path - self._fused_fp8_set_kv_buffer( - q=q, - k=k, - v=v, - layer=layer, - forward_batch=forward_batch, - ) - k = None - v = None - else: - # Use 
original set_kv_buffer path - if save_kv_cache and k is not None: + # process k/v global scale + v_scale = layer.v_scale + k_scale = layer.k_scale + if not hasattr(layer, "k_scale") or layer.k_scale is None: + k_scale = 1.0 + if not hasattr(layer, "v_scale") or layer.v_scale is None: + v_scale = 1.0 + + if self.is_nvfp4_kvcache: + cur_k_scale_gpu = self.k_scales_gpu[layer.layer_id : layer.layer_id + 1] + cur_v_scale_gpu = self.v_scales_gpu[layer.layer_id : layer.layer_id + 1] + + # save k/v cache to kv pool + if save_kv_cache and k is not None: + if use_fused_fp8_path: + # Use fused FP8 quantization + KV cache write path + self._fused_fp8_set_kv_buffer( + q=q, + k=k, + v=v, + layer=layer, + forward_batch=forward_batch, + ) + k = None + v = None + elif self.is_nvfp4_kvcache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer=layer, + loc=cache_loc, + cache_k=k, + cache_v=v, + k_scale=cur_k_scale_gpu, + v_scale=cur_v_scale_gpu, + ) + else: forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + layer, cache_loc, k, v, k_scale, v_scale ) - if self.data_type == torch.float8_e4m3fn: - q = q.to(torch.float8_e4m3fn) + if ( + self.is_nvfp4_kvcache + and forward_batch.forward_mode.is_extend_without_speculative() + ): + # path: + # 1. nvfp4, target model prefill/chunkedprefill, w/o cudagraph, is_extend_without_speculative(), context mha kernel + # 2. 
nvfp4, draft model prefill/chunkedprefill, w/o cudagraph, is_extend_without_speculative(), context mha kernel + from sglang.srt.layers.quantization.fp4_utils import NVFP4QuantizeUtil + + batch_size = forward_batch.batch_size + + k_buffer_nvfp4, k_scales_buffer = ( + forward_batch.token_to_kv_pool.get_fp4_key_buffer(layer.layer_id) + ) + v_buffer_nvfp4, v_scales_buffer = ( + forward_batch.token_to_kv_pool.get_fp4_value_buffer(layer.layer_id) + ) + k_buffer_dq, v_buffer_dq = forward_batch.token_to_kv_pool.get_dq_kv_buffer() + + # Convert current k/v to fp8 once + k_cur_fp8 = k.to(torch.float8_e4m3fn) + v_cur_fp8 = v.to(torch.float8_e4m3fn) + + # Process each request in batch + cur_batch_start_loc_cpu = 0 + # skip first page for dummy output + cur_token_idx_dq_buffer_cpu = self.page_size + for batch_idx in range(batch_size): + req_pool_idx = self.cpu_req_pool_indices[batch_idx] + prev_len = forward_batch.extend_prefix_lens_cpu[batch_idx] + extend_len = forward_batch.extend_seq_lens_cpu[batch_idx] + # prev_len = self.prefix_lengths_kv_cpu[batch_idx] + # extend_len = self.extend_lengths_kv_cpu[batch_idx] + + # Dequantize and copy previous KV + if prev_len > 0: + prev_token_indices = forward_batch.req_to_token_pool.req_to_token[ + req_pool_idx, :prev_len + ] + k_prev_nvfp4 = k_buffer_nvfp4[prev_token_indices] + k_prev_scales = k_scales_buffer[prev_token_indices] + v_prev_nvfp4 = v_buffer_nvfp4[prev_token_indices] + v_prev_scales = v_scales_buffer[prev_token_indices] + + # Dequantize: [prev_len, num_heads, head_dim] + k_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize( + k_prev_nvfp4.view(torch.uint8), + k_prev_scales, + cur_k_scale_gpu, + ) + v_prev_bf16 = NVFP4QuantizeUtil.cuda_nvfp4_dequantize( + v_prev_nvfp4.view(torch.uint8), + v_prev_scales, + cur_v_scale_gpu, + ) + k_prev_fp8 = k_prev_bf16.to(torch.float8_e4m3fn) + v_prev_fp8 = v_prev_bf16.to(torch.float8_e4m3fn) + + # Direct continuous copy + k_buffer_dq[ + cur_token_idx_dq_buffer_cpu : 
cur_token_idx_dq_buffer_cpu + + prev_len + ] = k_prev_fp8 + v_buffer_dq[ + cur_token_idx_dq_buffer_cpu : cur_token_idx_dq_buffer_cpu + + prev_len + ] = v_prev_fp8 + + # Write of current chunk + cur_end = cur_batch_start_loc_cpu + extend_len + k_cur_chunk = k_cur_fp8[cur_batch_start_loc_cpu:cur_end] + v_cur_chunk = v_cur_fp8[cur_batch_start_loc_cpu:cur_end] + k_buffer_dq[ + cur_token_idx_dq_buffer_cpu + + prev_len : cur_token_idx_dq_buffer_cpu + + prev_len + + extend_len + ] = k_cur_chunk + v_buffer_dq[ + cur_token_idx_dq_buffer_cpu + + prev_len : cur_token_idx_dq_buffer_cpu + + prev_len + + extend_len + ] = v_cur_chunk + + cur_batch_start_loc_cpu = cur_end + # align to page size + cur_token_idx_dq_buffer_cpu = ( + ( + cur_token_idx_dq_buffer_cpu + + prev_len + + extend_len + + self.page_size + - 1 + ) + // self.page_size + * self.page_size + ) + + k_paged = k_buffer_dq.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ).permute(0, 2, 1, 3) + v_paged = v_buffer_dq.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ).permute(0, 2, 1, 3) + + # Use paged KV cache for attention + kv_cache = (k_paged, v_paged) + kv_block_scales = None + # self.dq_page_table is only used for nvfp4 target model prefill path + cur_step_page_table = self.dq_page_table + elif self.is_nvfp4_kvcache and ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) + ): + assert ( + False + ), "NVFP4 kv cache is not supported for MTP draft extend for now." 
+ k_cache, k_cache_scales = forward_batch.token_to_kv_pool.get_fp4_key_buffer( + layer.layer_id + ) + v_cache, v_cache_scales = ( + forward_batch.token_to_kv_pool.get_fp4_value_buffer(layer.layer_id) + ) + k_cache = k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim // 2 + ).permute(0, 2, 1, 3) + v_cache = v_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim // 2 + ).permute(0, 2, 1, 3) + + k_cache_scales = k_cache_scales.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim // 16 + ).permute(0, 2, 1, 3) + v_cache_scales = v_cache_scales.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim // 16 + ).permute(0, 2, 1, 3) + + kv_cache = (k_cache, v_cache) + kv_block_scales = (k_cache_scales, v_cache_scales) + cur_step_page_table = self.forward_metadata.page_table + else: + # bf16/fp8, all paths + # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim] + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + k_cache = k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ).permute(0, 2, 1, 3) + v_cache = v_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ).permute(0, 2, 1, 3) + kv_cache = (k_cache, v_cache) + kv_block_scales = None + cur_step_page_table = self.forward_metadata.page_table + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim] - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - k_cache = k_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.head_dim - ).permute(0, 2, 1, 3) - v_cache = v_cache.view( - -1, self.page_size, layer.tp_v_head_num, layer.head_dim - ).permute(0, 2, 1, 3) - kv_cache = (k_cache, v_cache) + if self.data_type == torch.float8_e4m3fn or self.is_nvfp4_kvcache: + q = q.to(torch.float8_e4m3fn) # sink: 
additional value per head in the denominator of the softmax. attention_sink = kwargs.get("sinks", None) # TODO: add support for quantization q_scale = 1.0 - k_scale = ( - layer.k_scale_float - if getattr(layer, "k_scale_float", None) is not None - else 1.0 - ) + # k_scale = ( + # # where to load k global and v global scales? + # layer.k_scale_float + # if getattr(layer, "k_scale_float", None) is not None + # else 1.0 + # ) bmm1_scale = q_scale * k_scale * layer.scaling - bmm2_scale = 1.0 + bmm2_scale = v_scale if forward_batch.forward_mode.is_target_verify(): + # Paths: + # 4. bf16/fp8/nvfp4, target model verify, w/ cudagraph, is_target_verify(), decode mha kernel o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=q, kv_cache=kv_cache, + kv_block_scales=kv_block_scales, workspace_buffer=self.workspace_buffer, - block_tables=self.forward_metadata.page_table, + block_tables=cur_step_page_table, seq_lens=self.forward_metadata.cache_seqlens_int32, max_seq_len=self.max_context_len, bmm1_scale=bmm1_scale, @@ -711,16 +1069,16 @@ def forward_extend( window_left=layer.sliding_window_size, # TODO: add attention_sink operation or nvfp4 scale factor if needed sinks=attention_sink, - out_dtype=self.q_data_type, # model_runner.dtype + out_dtype=q.dtype, # fp4 kv kernel doesn't support bf16 output q_len_per_req=self.forward_metadata.max_seq_len_q, ) else: - + # TODO(Sam): NVFP4 kv cache is not supported or MTP. Because draft extend will invoke this api, it needs nvfp4 kv cache support. 
o = flashinfer.prefill.trtllm_batch_context_with_kv_cache( query=q, kv_cache=kv_cache, workspace_buffer=self.workspace_buffer, - block_tables=self.forward_metadata.page_table, + block_tables=cur_step_page_table, seq_lens=self.forward_metadata.cache_seqlens_int32, max_q_len=self.forward_metadata.max_seq_len_q, max_kv_len=self.max_context_len, @@ -730,10 +1088,13 @@ def forward_extend( cum_seq_lens_q=self.forward_metadata.cu_seqlens_q, cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k, window_left=layer.sliding_window_size, - # TODO: add attention_sink operation or nvfp4 scale factor if needed + # TODO: add attention_sink operation scale factor if needed sinks=attention_sink, - out_dtype=self.q_data_type, # model_runner.dtype + out_dtype=self.q_data_type, ) + # import sys + # sys.stdout.flush() + o = o.to(self.q_data_type) return o.view(-1, layer.tp_q_head_num * layer.head_dim) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 61219f6b04a7..e383d74b860a 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -245,6 +245,8 @@ def _dispatch_auto_backend() -> Callable: # 4. AITER (if AMD GPU with AITER enabled) # 5. Triton (fallback) + return triton_w8a8_block_fp8_linear + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM: return deepgemm_w8a8_block_fp8_linear_with_fallback elif is_blackwell_supported() and is_flashinfer_available(): diff --git a/python/sglang/srt/layers/quantization/kvfp4_tensor.py b/python/sglang/srt/layers/quantization/kvfp4_tensor.py index 545199ff0a35..ccecf7a0d4f9 100644 --- a/python/sglang/srt/layers/quantization/kvfp4_tensor.py +++ b/python/sglang/srt/layers/quantization/kvfp4_tensor.py @@ -12,13 +12,46 @@ # limitations under the License. 
# ============================================================================== +# Define a enum class for FP4 formats, including MXFP4, NVFP4 and future formats +from enum import Enum + import torch + +class FP4KVCacheRecipe(Enum): + MXFP4 = 1 # KVFP4: block-wise scaling + NVFP4 = 2 # two-level scaling: global FP32 + block FP8 E4M3 + + E2M1_MAX = 6.0 +MAX_BLOCK_SCALE_FP8 = 448.0 # Maximum FP8 E4M3 value # Put constants directly on CUDA if available _device = "cuda" if torch.cuda.is_available() else "cpu" +# E2M1 format: 1 sign bit + 2 exponent bits + 1 mantissa bit = 4 bits +# 16 possible values: 0x0-0xF +# Negative values: 0x8-0xF (sign bit = 1) +# Positive values: 0x0-0x7 (sign bit = 0) E2M1_VALUES = torch.tensor( - [0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.float32, device=_device + [ + 0, + 0.5, + 1, + 1.5, + 2, + 3, + 4, + 6, # 0x0-0x7: positive values + -0, + -0.5, + -1, + -1.5, + -2, + -3, + -4, + -6, + ], # 0x8-0xF: negative values + dtype=torch.float32, + device=_device, ) E2M1_BOUNDS = torch.tensor( [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32, device=_device @@ -32,6 +65,7 @@ class KVFP4QuantizeUtil: @torch.compile def batched_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ + Quantize tensor to KVFP4 format Args: tensor: Input tensor of shape [B, M, N] @@ -110,3 +144,968 @@ def batched_dequantize( scaled = reshaped * torch.exp2(scale_exp.unsqueeze(-1)) return scaled.view(b, m, n).to(dtype) + + +class NVFP4QuantizeUtil: + """Utility class for NVFP4 quantization and dequantization with two-level scaling (global FP32 + block FP8).""" + + # Cached kernel modules + _nvfp4_dequant_module = None + _nvfp4_quant_sm100_module = None + + @staticmethod + def fi_nvfp4_quantize(tensor: torch.Tensor, global_scale: torch.Tensor): + # input and output shape [B, M, N] + # return uint8 cache and fp8 block scales + try: + from flashinfer import fp4_quantize + except ImportError: + raise ImportError( + "flashinfer is installed 
correctly. Please install flashinfer to use NVFP4 KV cache." + ) + global_scale_inv = 1.0 / global_scale + if isinstance(global_scale_inv, float): + global_scale_inv = torch.tensor( + global_scale_inv, dtype=torch.float32, device=tensor.device + ) + assert ( + global_scale_inv.device == tensor.device + ), "global_scale and tensor must be on the same device" + b, m, n = tensor.shape + tensor_reshaped = tensor.reshape(b * m, n) + tensor_fp4, tensor_fp4_sf = fp4_quantize( + tensor_reshaped, + global_scale_inv, + sf_vec_size=16, + sf_use_ue8m0=False, + is_sf_swizzled_layout=False, + is_sf_8x4_layout=False, + enable_pdl=None, + ) + tensor_fp4 = tensor_fp4.view(b, m, tensor_fp4.shape[-1]) + tensor_fp4_sf = tensor_fp4_sf.view(b, m, tensor_fp4_sf.shape[-1]).view( + torch.float8_e4m3fn + ) + return tensor_fp4, tensor_fp4_sf, global_scale + + @staticmethod + def cuda_nvfp4_dequantize( + quant_tensor: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, + ) -> torch.Tensor: + """ + Dequantize NVFP4 tensor using optimized CUDA kernel with vectorization and shared memory. + + This is a fast kernel-based implementation that provides significant performance improvements + over the pure PyTorch implementation. 
+ + Args: + quant_tensor: Quantized E2M1 tensor of shape [B, M, N/2] (packed uint8) + block_scales: Block scale factors of shape [B, M, N/16] (FP8 E4M3) + global_scale: Global scale factor (float32 scalar or tensor) + dtype: Target dtype for output (torch.bfloat16 or torch.float16) + + Returns: + Dequantized tensor of shape [B, M, N] + """ + b, m, n_half = quant_tensor.shape + n = n_half * 2 + + # Ensure inputs are on CUDA + assert quant_tensor.is_cuda, "quant_tensor must be on CUDA" + assert block_scales.is_cuda, "block_scales must be on CUDA" + assert quant_tensor.dtype == torch.uint8, "quant_tensor must be uint8" + + # Handle global_scale conversion - ensure it's a CUDA tensor + if isinstance(global_scale, (int, float)): + global_scale = torch.tensor( + [global_scale], dtype=torch.float32, device=quant_tensor.device + ) + else: + # Ensure global_scale is on CUDA and is a 1D tensor with 1 element + if not global_scale.is_cuda: + global_scale = global_scale.to(quant_tensor.device, non_blocking=True) + if global_scale.dim() == 0: + global_scale = global_scale.unsqueeze(0) + global_scale = global_scale.contiguous() + + # Get the kernel module + module = NVFP4QuantizeUtil._get_dequant_module() + + # Reshape to 2D for kernel: [B*M, N/2] and [B*M, N/16] + quant_2d = quant_tensor.reshape(b * m, n_half) + + # Convert FP8 E4M3 block scales to uint8 view for kernel + if block_scales.dtype == torch.float8_e4m3fn: + scales_2d = block_scales.view(torch.uint8).reshape(b * m, -1) + else: + scales_2d = block_scales.reshape(b * m, -1) + + # Call appropriate kernel based on dtype + if dtype == torch.bfloat16: + output_2d = module.nvfp4_dequant_to_bf16(quant_2d, scales_2d, global_scale) + elif dtype == torch.float16: + output_2d = module.nvfp4_dequant_to_fp16(quant_2d, scales_2d, global_scale) + else: + raise ValueError( + f"Unsupported dtype: {dtype}. Only torch.bfloat16 and torch.float16 are supported." 
+ ) + + # Reshape back to 3D: [B, M, N] + output = output_2d.reshape(b, m, n) + + return output + + @staticmethod + def cuda_nvfp4_quantize_blackwell( + tensor: torch.Tensor, + global_scale: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantize tensor to NVFP4 format using optimized SM100 CUDA kernel. + + This is a fast kernel-based implementation optimized for SM100 architecture + that uses native E2M1 conversion instructions. + + Args: + tensor: Input tensor of shape [B, M, N] (bfloat16 or float16) + global_scale: Global scale factor (float32 scalar or tensor) + dtype: Input dtype (must match tensor dtype) + + Returns: + quant_tensor: Quantized E2M1 tensor of shape [B, M, N/2] (packed uint8) + block_scales: Block scale factors of shape [B, M, N/16] (FP8 E4M3) + global_scale: Global scale factor (float32 scalar tensor) + """ + b, m, n = tensor.shape + + # Ensure inputs are on CUDA + assert tensor.is_cuda, "tensor must be on CUDA" + assert dtype in [ + torch.bfloat16, + torch.float16, + ], "Only bfloat16 and float16 are supported" + assert ( + tensor.dtype == dtype + ), f"tensor dtype {tensor.dtype} must match specified dtype {dtype}" + assert n % 16 == 0, "N dimension must be divisible by 16" + + # Handle global_scale conversion - ensure it's a CUDA tensor + if isinstance(global_scale, (int, float)): + global_scale = torch.tensor( + [global_scale], dtype=torch.float32, device=tensor.device + ) + else: + # Ensure global_scale is on CUDA and is a 1D tensor with 1 element + if not global_scale.is_cuda: + global_scale = global_scale.to(tensor.device, non_blocking=True) + if global_scale.dim() == 0: + global_scale = global_scale.unsqueeze(0) + global_scale = global_scale.contiguous() + + # Get the kernel module + module = NVFP4QuantizeUtil._get_quant_sm100_module() + + # Reshape to 2D for kernel: [B*M, N] + tensor_2d = tensor.reshape(b * m, n).contiguous() + + # Call appropriate kernel based on 
dtype + if dtype == torch.bfloat16: + quant_2d, scales_2d = module.nvfp4_quant_from_bf16(tensor_2d, global_scale) + elif dtype == torch.float16: + quant_2d, scales_2d = module.nvfp4_quant_from_fp16(tensor_2d, global_scale) + else: + raise ValueError( + f"Unsupported dtype: {dtype}. Only torch.bfloat16 and torch.float16 are supported." + ) + + # Reshape back to 3D: [B, M, N/2] and [B, M, N/16] + quant_tensor = quant_2d.reshape(b, m, n // 2) + + # Convert uint8 block scales back to FP8 E4M3 + block_scales = scales_2d.view(torch.float8_e4m3fn).reshape(b, m, n // 16) + + return quant_tensor, block_scales, global_scale + + @staticmethod + # @torch.compile + def batched_quantize( + tensor: torch.Tensor, global_scale: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantize tensor to NVFP4 format with two-level scaling (global FP32 + block FP8 E4M3) + + Formula: x_fp4 * block_scale * global_scale = x_bf16 + + Args: + tensor: Input tensor of shape [B, M, N] + global_scale: Optional global scale factor (float32 scalar). + If None, will auto-compute per-tensor global scale. + If provided, will use the given global scale. 
+ + Returns: + quant_tensor: Quantized E2M1 tensor of shape [B, M, N/2] (packed uint8) + block_scales: Block scale factors of shape [B, M*N/16] (FP8 E4M3) + global_scale: Global scale factor (float32 scalar) + """ + b, m, n = tensor.shape + device = tensor.device + + # Step 1: Calculate global_scale + if global_scale is None: + global_max = tensor.abs().amax() + global_scale = torch.tensor( + global_max.item() / (E2M1_MAX * MAX_BLOCK_SCALE_FP8), + dtype=torch.float32, + device=device, + ) + else: + # Use provided global scale + if not isinstance(global_scale, torch.Tensor): + global_scale = torch.tensor( + global_scale, dtype=torch.float32, device=device + ) + else: + global_scale = global_scale.to(device=device, dtype=torch.float32) + + if global_scale < 1e-6: + global_scale = torch.tensor(1e-6, dtype=torch.float32, device=device) + + # Step 2: Scale x_bf16 to FP4 range [-6, 6] + # First, reshape to blocks [B, M*N/16, 16] + reshaped = tensor.float().view(b, m, n // 16, 16) + block_max = reshaped.abs().amax(dim=-1, keepdim=True) + block_scales = block_max.squeeze(-1) / (E2M1_MAX * global_scale) + block_scales = torch.clamp(block_scales, 0.0, MAX_BLOCK_SCALE_FP8) + block_scales_fp8 = block_scales.to(torch.float8_e4m3fn) + + # Scale each block to FP4 range: x_scaled = x / block_max * E2M1_MAX + # This ensures values are in [-6, 6] range + block_scales_fixed = block_scales.unsqueeze(-1) + x_scaled = reshaped / (block_scales_fixed * global_scale) + + # Step 3: Convert scaled values (x_scaled) to packed FP4 + # x_scaled is already in FP4 range [-6, 6] in bf16 representation + # Now quantize to E2M1 format + + # E2M1 format: bit 3 = sign, bits 2-0 = magnitude (exponent + mantissa) + sign_bits = (x_scaled < 0).to(torch.uint8) << 3 # bit 3: sign bit + abs_vals = x_scaled.abs() + # Find nearest E2M1 magnitude (0-7) using boundaries + magnitude_bits = torch.sum(abs_vals.unsqueeze(-1) >= E2M1_BOUNDS, dim=-1).to( + torch.uint8 + ) + # Combine sign and magnitude: 4-bit value = 
sign_bit | magnitude + fp4_vals = sign_bits | magnitude_bits + # Pack two FP4 values into one uint8 + fp4_reshaped = fp4_vals.view(b, m, n) + packed = (fp4_reshaped[..., 1::2] << 4) + fp4_reshaped[..., 0::2] + + return packed, block_scales_fp8, global_scale + + @staticmethod + # @torch.compile + def batched_dequantize( + quant_tensor: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, + ) -> torch.Tensor: + """ + Dequantize NVFP4 tensor with two-level scaling (global FP32 + block FP8 E4M3) + + Args: + quant_tensor: Quantized E2M1 tensor of shape [B, M, N/2] (packed uint8) + block_scales: Block scale factors of shape [B, M*N/16] (FP8 E4M3) + global_scale: Global scale factor (float32 scalar) + dtype: Target dtype for output + + Returns: + Dequantized tensor of shape [B, M, N] + """ + b, m, n_half = quant_tensor.shape + n = n_half * 2 + + # More efficient unpacking using bit operations + fp4_vals = torch.empty(b, m, n, dtype=torch.uint8, device=quant_tensor.device) + fp4_vals[..., 0::2] = quant_tensor & 0x0F + fp4_vals[..., 1::2] = (quant_tensor >> 4) & 0x0F + + # Directly map 4-bit E2M1 values (0x0-0xF) to float + # E2M1_VALUES[0-7] = positive, E2M1_VALUES[8-15] = negative + float_vals = E2M1_VALUES[fp4_vals.long()] + + # Reshape for block-wise scaling + reshaped = float_vals.view(b, m * n // 16, 16) + + # Apply block scale factors (inverse scaling: divide by FP8 block scales) + # Convert FP8 back to float32 for computation + block_scales_float = block_scales.float().unsqueeze(-1) # [B, M*N/16, 1] + scaled = reshaped * block_scales_float + + # Apply inverse global scaling + dequantized = scaled.view(b, m, n) * global_scale + + return dequantized.to(dtype) + + @staticmethod + def _get_dequant_module(): + """Load and cache the NVFP4 dequantization kernel module.""" + if NVFP4QuantizeUtil._nvfp4_dequant_module is None: + from torch.utils.cpp_extension import load_inline + + CUDA_SOURCE = r""" +#include 
+#include +#include +#include +#include +#include +#include +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +typedef __nv_fp8_e4m3 fp8_e4m3; +typedef __nv_fp8x2_e4m3 fp8x2_e4m3; +#define HAS_FP8_SUPPORT 1 +#else +typedef uint8_t fp8_e4m3; +typedef uint16_t fp8x2_e4m3; +#define HAS_FP8_SUPPORT 0 +#endif + +// E2M1 lookup table +__device__ __constant__ float E2M1_LUT[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f +}; + +// Dequantize 4 FP4 values (packed in uint16_t) to float2x2 +__device__ __forceinline__ void dequant_fp4x4_to_float2x2( + uint16_t packed_fp4, + float scale_0, + float scale_1, + float global_scale, + float2& out0, + float2& out1 +) { + uint8_t fp4_0 = (packed_fp4 >> 0) & 0xF; + uint8_t fp4_1 = (packed_fp4 >> 4) & 0xF; + uint8_t fp4_2 = (packed_fp4 >> 8) & 0xF; + uint8_t fp4_3 = (packed_fp4 >> 12) & 0xF; + + out0.x = E2M1_LUT[fp4_0] * scale_0 * global_scale; + out0.y = E2M1_LUT[fp4_1] * scale_0 * global_scale; + out1.x = E2M1_LUT[fp4_2] * scale_1 * global_scale; + out1.y = E2M1_LUT[fp4_3] * scale_1 * global_scale; +} + +template +__global__ void nvfp4_dequant_vectorized_kernel( + const uint8_t* __restrict__ fp4_data, + const uint8_t* __restrict__ block_scales, + const float* __restrict__ global_scale_ptr, + OutType* __restrict__ output, + const int M, + const int K +) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + + if (row >= M) return; + + // Load global_scale from device memory once per block + __shared__ float global_scale; + __shared__ uint8_t smem_scales[512]; + + if (tid == 0) { + global_scale = *global_scale_ptr; + } + + const int K_scales = K / 16; + const int K_packed = K / 2; + + for (int i = tid; i < K_scales; i += BLOCK_SIZE) { + smem_scales[i] = block_scales[row * K_scales + i]; + } + __syncthreads(); + + constexpr int PACKED_PER_THREAD = ELTS_PER_THREAD / 2; + const int elts_per_block = BLOCK_SIZE * ELTS_PER_THREAD; + + const 
uint8_t* row_fp4 = fp4_data + row * K_packed; + OutType* row_output = output + row * K; + + for (int base_col = 0; base_col < K; base_col += elts_per_block) { + const int col_start = base_col + tid * ELTS_PER_THREAD; + + if (col_start >= K) break; + + #pragma unroll + for (int i = 0; i < PACKED_PER_THREAD / 2; ++i) { + const int col = col_start + i * 4; + if (col + 3 >= K) break; + + const int packed_idx = col / 2; + uint16_t packed_fp4 = *reinterpret_cast(&row_fp4[packed_idx]); + + const int scale_idx_0 = col / 16; + const int scale_idx_1 = (col + 2) / 16; + + const uint8_t scale_fp8_0 = smem_scales[scale_idx_0]; + const uint8_t scale_fp8_1 = smem_scales[scale_idx_1]; + +#if HAS_FP8_SUPPORT + const float scale_0 = static_cast(*reinterpret_cast(&scale_fp8_0)); + const float scale_1 = static_cast(*reinterpret_cast(&scale_fp8_1)); +#else + const float scale_0 = 1.0f; + const float scale_1 = 1.0f; +#endif + + float2 out0, out1; + dequant_fp4x4_to_float2x2(packed_fp4, scale_0, scale_1, global_scale, out0, out1); + + if constexpr (std::is_same_v) { + __nv_bfloat162 bf16_0 = __float22bfloat162_rn(out0); + __nv_bfloat162 bf16_1 = __float22bfloat162_rn(out1); + + *reinterpret_cast<__nv_bfloat162*>(&row_output[col]) = bf16_0; + *reinterpret_cast<__nv_bfloat162*>(&row_output[col + 2]) = bf16_1; + } else if constexpr (std::is_same_v) { + half2 h2_0 = __float22half2_rn(out0); + half2 h2_1 = __float22half2_rn(out1); + + *reinterpret_cast(&row_output[col]) = h2_0; + *reinterpret_cast(&row_output[col + 2]) = h2_1; + } + } + } +} + +torch::Tensor nvfp4_dequant_to_bf16_cuda_v2( + torch::Tensor fp4_data, + torch::Tensor block_scales, + torch::Tensor global_scale +) { + TORCH_CHECK(fp4_data.is_cuda(), "fp4_data must be CUDA tensor"); + TORCH_CHECK(fp4_data.dtype() == torch::kUInt8, "fp4_data must be uint8"); + TORCH_CHECK(global_scale.is_cuda(), "global_scale must be CUDA tensor"); + TORCH_CHECK(global_scale.dtype() == torch::kFloat32, "global_scale must be float32"); + + const int M 
= fp4_data.size(0); + const int K = fp4_data.size(1) * 2; + + auto output = torch::empty({M, K}, torch::TensorOptions() + .dtype(torch::kBFloat16).device(fp4_data.device())); + + constexpr int BLOCK_SIZE = 128; + dim3 grid(M); + dim3 block(BLOCK_SIZE); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + nvfp4_dequant_vectorized_kernel<__nv_bfloat16, BLOCK_SIZE, 16><<>>( + fp4_data.data_ptr(), + block_scales.data_ptr(), + global_scale.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), + M, K + ); + + return output; +} + +torch::Tensor nvfp4_dequant_to_fp16_cuda_v2( + torch::Tensor fp4_data, + torch::Tensor block_scales, + torch::Tensor global_scale +) { + TORCH_CHECK(fp4_data.is_cuda(), "fp4_data must be CUDA tensor"); + TORCH_CHECK(fp4_data.dtype() == torch::kUInt8, "fp4_data must be uint8"); + TORCH_CHECK(global_scale.is_cuda(), "global_scale must be CUDA tensor"); + TORCH_CHECK(global_scale.dtype() == torch::kFloat32, "global_scale must be float32"); + + const int M = fp4_data.size(0); + const int K = fp4_data.size(1) * 2; + + auto output = torch::empty({M, K}, torch::TensorOptions() + .dtype(torch::kFloat16).device(fp4_data.device())); + + constexpr int BLOCK_SIZE = 128; + dim3 grid(M); + dim3 block(BLOCK_SIZE); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + nvfp4_dequant_vectorized_kernel<<>>( + fp4_data.data_ptr(), + block_scales.data_ptr(), + global_scale.data_ptr(), + reinterpret_cast(output.data_ptr()), + M, K + ); + + return output; +} +""" + + CPP_SOURCE = r""" +#include + +torch::Tensor nvfp4_dequant_to_bf16_cuda_v2( + torch::Tensor fp4_data, + torch::Tensor block_scales, + torch::Tensor global_scale +); + +torch::Tensor nvfp4_dequant_to_fp16_cuda_v2( + torch::Tensor fp4_data, + torch::Tensor block_scales, + torch::Tensor global_scale +); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("nvfp4_dequant_to_bf16", &nvfp4_dequant_to_bf16_cuda_v2, "NVFP4 dequantize to BF16 V2"); + 
m.def("nvfp4_dequant_to_fp16", &nvfp4_dequant_to_fp16_cuda_v2, "NVFP4 dequantize to FP16 V2"); +} +""" + + NVFP4QuantizeUtil._nvfp4_dequant_module = load_inline( + name="nvfp4_dequant_v2", + cpp_sources=[CPP_SOURCE], + cuda_sources=[CUDA_SOURCE], + extra_cuda_cflags=[ + "-O3", + "--use_fast_math", + "-std=c++17", + "-DENABLE_BF16", + ], + verbose=False, + with_cuda=True, + ) + + return NVFP4QuantizeUtil._nvfp4_dequant_module + + @staticmethod + def _get_quant_sm100_module(): + """Load and cache the NVFP4 quantization kernel module for SM100 architecture.""" + if NVFP4QuantizeUtil._nvfp4_quant_sm100_module is None: + from torch.utils.cpp_extension import load_inline + + CUDA_SOURCE = r""" +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +typedef __nv_fp8_e4m3 fp8_e4m3; +#define HAS_FP8_SUPPORT 1 +#else +typedef uint8_t fp8_e4m3; +#define HAS_FP8_SUPPORT 0 +#endif + +// Helper functions +__device__ __forceinline__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +__device__ __forceinline__ __nv_bfloat162 cuda_abs(__nv_bfloat162 a) { + __nv_bfloat162 result; + float fx = fabsf(__bfloat162float(a.x)); + float fy = fabsf(__bfloat162float(a.y)); + result.x = __float2bfloat16(fx); + result.y = __float2bfloat16(fy); + return result; +} + +__device__ __forceinline__ half2 cuda_abs(half2 a) { + return __habs2(a); +} + +__device__ __forceinline__ __nv_bfloat162 cuda_max(__nv_bfloat162 a, __nv_bfloat162 b) { + __nv_bfloat162 result; + result.x = __bfloat162float(a.x) > __bfloat162float(b.x) ? a.x : b.x; + result.y = __bfloat162float(a.y) > __bfloat162float(b.y) ? a.y : b.y; + return result; +} + +__device__ __forceinline__ half2 cuda_max(half2 a, half2 b) { + return __hmax2(a, b); +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), "f"(array[2].y), + "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Quantize 8 FP16/BF16 values to E2M1 with FP8 E4M3 block scaling +template +__device__ uint32_t quantize_fp16_to_e2m1_with_scaling( + InType (&vec)[4], // 4 x Vec2 = 8 values + float global_scale, + uint8_t* block_scale_out +) { + constexpr int SF_VEC_SIZE = 16; + constexpr int CVT_ELTS_PER_THREAD = 8; + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD; + + // Get absolute maximum values among the local 8 values + auto localMax = cuda_abs(vec[0]); + + #pragma unroll + for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) { + localMax = cuda_max(localMax, cuda_abs(vec[i])); + } + + // Get the absolute maximum among all 16 values (two threads for 16) + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + if constexpr (CVT_NUM_THREADS_PER_SF == 4) { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + } + + // Get the final absolute maximum value + float vecMax; + if constexpr (std::is_same_v) { + auto max_single = __bfloat162float(localMax.x) > __bfloat162float(localMax.y) ? 
localMax.x : localMax.y; + vecMax = __bfloat162float(max_single); + } else { + vecMax = fmaxf(__half2float(localMax.x), __half2float(localMax.y)); + } + + // Calculate block scale factor (FP8 E4M3) + uint8_t fp8_scale_val = 0; + float output_scale = 0.0f; + + // Get the SF (max value of the vector / max value of e2m1) + // maximum value of e2m1 = 6.0 + auto sf_value = reciprocal_approximate_ftz(global_scale) * (vecMax * reciprocal_approximate_ftz(6.0f)); + +#if HAS_FP8_SUPPORT + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(sf_value); + fp8_scale_val = tmp.__x; + sf_value = static_cast(tmp); +#else + // Fallback: clamp to uint8 range + fp8_scale_val = static_cast(fminf(fmaxf(sf_value, 0.0f), 255.0f)); + sf_value = static_cast(fp8_scale_val); +#endif + + // Get the output scale + output_scale = vecMax != 0 ? reciprocal_approximate_ftz(sf_value * global_scale) : 0.0f; + + // Write block scale + if (block_scale_out) { + *block_scale_out = fp8_scale_val; + } + + // Convert the input to float and apply scaling + float2 fp2_vals[CVT_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2_vals[i] = __bfloat1622float2(vec[i]); + } else { + fp2_vals[i] = __half22float2(vec[i]); + } + fp2_vals[i].x *= output_scale; + fp2_vals[i].y *= output_scale; + } + + // Convert to e2m1 values (FP4) + uint32_t e2m1_vec = fp32_vec_to_e2m1(fp2_vals); + + return e2m1_vec; +} + +// Quantization kernel for BF16 to NVFP4 +template +__global__ void nvfp4_quant_from_bf16_kernel( + const __nv_bfloat16* __restrict__ input, + const float* __restrict__ global_scale_ptr, + uint8_t* __restrict__ fp4_output, + uint8_t* __restrict__ block_scales, + const int M, + const int K +) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + + if (row >= M) return; + + // Load global_scale from device memory once per block + __shared__ float global_scale; + + if (tid == 0) { + global_scale = *global_scale_ptr; + } + __syncthreads(); + + 
constexpr int CVT_ELTS_PER_THREAD = 8;
+    constexpr int PACKED_PER_THREAD = CVT_ELTS_PER_THREAD / 2;
+    const int elts_per_block = BLOCK_SIZE * CVT_ELTS_PER_THREAD;
+
+    const __nv_bfloat16* row_input = input + row * K;
+    uint8_t* row_fp4 = fp4_output + row * (K / 2);
+    uint8_t* row_scales = block_scales + row * (K / 16);
+
+    for (int base_col = 0; base_col < K; base_col += elts_per_block) {
+        const int col_start = base_col + tid * CVT_ELTS_PER_THREAD;
+
+        if (col_start >= K) break;
+
+        // Load 8 BF16 values as 4 x BF16x2
+        __nv_bfloat162 vec[4];
+
+        #pragma unroll
+        for (int i = 0; i < PACKED_PER_THREAD; ++i) {
+            const int col = col_start + i * 2;
+            if (col + 1 < K) {
+                vec[i] = *reinterpret_cast<const __nv_bfloat162*>(&row_input[col]);
+            } else if (col < K) {
+                vec[i].x = row_input[col];
+                vec[i].y = __float2bfloat16(0.0f);
+            } else {
+                vec[i] = __float2bfloat162_rn(0.0f);
+            }
+        }
+
+        // Quantize to E2M1 with block scaling
+        const int block_idx = col_start / 16;
+        uint8_t* scale_out = (tid % 2 == 0) ? &row_scales[block_idx] : nullptr;
+
+        uint32_t e2m1_vals = quantize_fp16_to_e2m1_with_scaling(vec, global_scale, scale_out);
+
+        // Pack into output (4 bytes = 8 FP4 values)
+        const int packed_idx = col_start / 2;
+        if (packed_idx + 3 < K / 2) {
+            *reinterpret_cast<uint32_t*>(&row_fp4[packed_idx]) = e2m1_vals;
+        } else {
+            // Handle boundary case
+            uint8_t* bytes = reinterpret_cast<uint8_t*>(&e2m1_vals);
+            for (int i = 0; i < 4 && packed_idx + i < K / 2; ++i) {
+                row_fp4[packed_idx + i] = bytes[i];
+            }
+        }
+    }
+}
+
+// Quantization kernel for FP16 to NVFP4
+template <int BLOCK_SIZE>
+__global__ void nvfp4_quant_from_fp16_kernel(
+    const half* __restrict__ input,
+    const float* __restrict__ global_scale_ptr,
+    uint8_t* __restrict__ fp4_output,
+    uint8_t* __restrict__ block_scales,
+    const int M,
+    const int K
+) {
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    if (row >= M) return;
+
+    // Load global_scale from device memory once per block
+    __shared__ float global_scale;
+
+    if (tid == 0) {
+        global_scale = *global_scale_ptr;
+    }
+    __syncthreads();
+
+    constexpr int CVT_ELTS_PER_THREAD = 8;
+    constexpr int PACKED_PER_THREAD = CVT_ELTS_PER_THREAD / 2;
+    const int elts_per_block = BLOCK_SIZE * CVT_ELTS_PER_THREAD;
+
+    const half* row_input = input + row * K;
+    uint8_t* row_fp4 = fp4_output + row * (K / 2);
+    uint8_t* row_scales = block_scales + row * (K / 16);
+
+    for (int base_col = 0; base_col < K; base_col += elts_per_block) {
+        const int col_start = base_col + tid * CVT_ELTS_PER_THREAD;
+
+        if (col_start >= K) break;
+
+        // Load 8 FP16 values as 4 x FP16x2
+        half2 vec[4];
+
+        #pragma unroll
+        for (int i = 0; i < PACKED_PER_THREAD; ++i) {
+            const int col = col_start + i * 2;
+            if (col + 1 < K) {
+                vec[i] = *reinterpret_cast<const half2*>(&row_input[col]);
+            } else if (col < K) {
+                vec[i].x = row_input[col];
+                vec[i].y = __float2half(0.0f);
+            } else {
+                vec[i] = __float2half2_rn(0.0f);
+            }
+        }
+
+        // Quantize to E2M1 with block scaling
+        const int block_idx = col_start / 16;
+        uint8_t* scale_out = (tid % 2 == 0) ? &row_scales[block_idx] : nullptr;
+
+        uint32_t e2m1_vals = quantize_fp16_to_e2m1_with_scaling(vec, global_scale, scale_out);
+
+        // Pack into output (4 bytes = 8 FP4 values)
+        const int packed_idx = col_start / 2;
+        if (packed_idx + 3 < K / 2) {
+            *reinterpret_cast<uint32_t*>(&row_fp4[packed_idx]) = e2m1_vals;
+        } else {
+            // Handle boundary case
+            uint8_t* bytes = reinterpret_cast<uint8_t*>(&e2m1_vals);
+            for (int i = 0; i < 4 && packed_idx + i < K / 2; ++i) {
+                row_fp4[packed_idx + i] = bytes[i];
+            }
+        }
+    }
+}
+
+std::vector<torch::Tensor> nvfp4_quant_from_bf16_cuda(
+    torch::Tensor input,
+    torch::Tensor global_scale
+) {
+    TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor");
+    TORCH_CHECK(input.dtype() == torch::kBFloat16, "input must be bfloat16");
+    TORCH_CHECK(global_scale.is_cuda(), "global_scale must be CUDA tensor");
+    TORCH_CHECK(global_scale.dtype() == torch::kFloat32, "global_scale must be float32");
+
+    const int M = input.size(0);
+    const int K = input.size(1);
+
+    TORCH_CHECK(K % 16 == 0, "K dimension must be divisible by 16");
+
+    auto fp4_output = torch::empty({M, K / 2}, torch::TensorOptions()
+        .dtype(torch::kUInt8).device(input.device()));
+
+    auto block_scales = torch::empty({M, K / 16}, torch::TensorOptions()
+        .dtype(torch::kUInt8).device(input.device()));
+
+    constexpr int BLOCK_SIZE = 128;
+    dim3 grid(M);
+    dim3 block(BLOCK_SIZE);
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    nvfp4_quant_from_bf16_kernel<BLOCK_SIZE><<<grid, block, 0, stream>>>(
+        reinterpret_cast<const __nv_bfloat16*>(input.data_ptr()),
+        global_scale.data_ptr<float>(),
+        fp4_output.data_ptr<uint8_t>(),
+        block_scales.data_ptr<uint8_t>(),
+        M, K
+    );
+
+    return {fp4_output, block_scales};
+}
+
+std::vector<torch::Tensor> nvfp4_quant_from_fp16_cuda(
+    torch::Tensor input,
+    torch::Tensor global_scale
+) {
+    TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor");
+    TORCH_CHECK(input.dtype() == torch::kFloat16, "input must be float16");
+    TORCH_CHECK(global_scale.is_cuda(), "global_scale must be CUDA tensor");
+    TORCH_CHECK(global_scale.dtype() == torch::kFloat32, "global_scale must be float32");
+
+    const int M = input.size(0);
+    const int K = input.size(1);
+
+    TORCH_CHECK(K % 16 == 0, "K dimension must be divisible by 16");
+
+    auto fp4_output = torch::empty({M, K / 2}, torch::TensorOptions()
+        .dtype(torch::kUInt8).device(input.device()));
+
+    auto block_scales = torch::empty({M, K / 16}, torch::TensorOptions()
+        .dtype(torch::kUInt8).device(input.device()));
+
+    constexpr int BLOCK_SIZE = 128;
+    dim3 grid(M);
+    dim3 block(BLOCK_SIZE);
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    nvfp4_quant_from_fp16_kernel<BLOCK_SIZE><<<grid, block, 0, stream>>>(
+        reinterpret_cast<const half*>(input.data_ptr()),
+        global_scale.data_ptr<float>(),
+        fp4_output.data_ptr<uint8_t>(),
+        block_scales.data_ptr<uint8_t>(),
+        M, K
+    );
+
+    return {fp4_output, block_scales};
+}
+"""
+
+        CPP_SOURCE = r"""
+#include <torch/extension.h>
+
+std::vector<torch::Tensor> nvfp4_quant_from_bf16_cuda(
+    torch::Tensor input,
+    torch::Tensor global_scale
+);
+
+std::vector<torch::Tensor> nvfp4_quant_from_fp16_cuda(
+    torch::Tensor input,
+    torch::Tensor global_scale
+);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("nvfp4_quant_from_bf16", &nvfp4_quant_from_bf16_cuda, "NVFP4 quantize from BF16 (SM100)");
+    m.def("nvfp4_quant_from_fp16", &nvfp4_quant_from_fp16_cuda, "NVFP4 quantize from FP16 (SM100)");
+}
+"""
+
+        NVFP4QuantizeUtil._nvfp4_quant_sm100_module = load_inline(
+            name="nvfp4_quant_sm100",
+            cpp_sources=[CPP_SOURCE],
+            cuda_sources=[CUDA_SOURCE],
+            extra_cuda_cflags=[
+                "-O3",
+                "--use_fast_math",
+                "-std=c++17",
+                "-DENABLE_BF16",
+                "--expt-relaxed-constexpr",
+                "-gencode=arch=compute_100a,code=sm_100a",
+                "-gencode=arch=compute_120a,code=sm_120a",
+            ],
+            verbose=False,
+            with_cuda=True,
+        )
+
+        return NVFP4QuantizeUtil._nvfp4_quant_sm100_module
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 0afbb15fd7e8..202930fa484d 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -45,6 +45,7 @@
     quantize_k_cache,
     quantize_k_cache_separate,
 )
+from 
sglang.srt.layers.quantization.kvfp4_tensor import NVFP4QuantizeUtil from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.utils import ( get_mla_kv_buffer_triton, @@ -52,7 +53,14 @@ set_mla_kv_buffer_triton, set_mla_kv_scale_buffer_triton, ) -from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2 +from sglang.srt.utils import ( + is_cuda, + is_float4_e2m1fn_x2, + is_hip, + is_npu, + next_power_of_2, +) +from sglang.srt.utils.common import is_sm100_supported, is_sm120_supported from sglang.srt.utils.custom_op import register_custom_op from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -1020,6 +1028,188 @@ def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): ) +class MHATokenToKVPoolNVFP4(MHATokenToKVPool): + + def _create_buffers(self): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.enable_custom_mem_pool + else nullcontext() + ): + # [size, head_num, head_dim] for each layer + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
+ m = self.size + self.page_size + n = self.head_num + k = self.head_dim + + scale_block_size = 16 + self.store_dtype = torch.uint8 + self.k_buffer = [ + torch.zeros( + (m, n, k // 2), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.zeros( + (m, n, k // 2), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + self.k_scale_buffer = [ + torch.zeros( + (m, n, k // scale_block_size), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_scale_buffer = [ + torch.zeros( + (m, n, k // scale_block_size), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + self.dq_dtype = torch.float8_e4m3fn + self.dq_k_buffer = torch.zeros( + (m, n, k), + dtype=self.dq_dtype, + device=self.device, + ) + self.dq_v_buffer = torch.zeros( + (m, n, k), + dtype=self.dq_dtype, + device=self.device, + ) + + def _clear_buffers(self): + del self.k_buffer + del self.v_buffer + del self.k_scale_buffer + del self.v_scale_buffer + del self.dq_k_buffer + del self.dq_v_buffer + + def _get_key_nvfp4_from_nvfp4_buffer(self, layer_id: int): + return ( + self.k_buffer[layer_id - self.start_layer], + self.k_scale_buffer[layer_id - self.start_layer].view(torch.float8_e4m3fn), + ) + + def _get_value_nvfp4_from_nvfp4_buffer(self, layer_id: int): + return ( + self.v_buffer[layer_id - self.start_layer], + self.v_scale_buffer[layer_id - self.start_layer].view(torch.float8_e4m3fn), + ) + + def get_fp4_value_buffer(self, layer_id: int): + return self._get_value_nvfp4_from_nvfp4_buffer(layer_id) + + def get_fp4_key_buffer(self, layer_id: int): + return self._get_key_nvfp4_from_nvfp4_buffer(layer_id) + + def _get_key_buffer(self, layer_id: int, k_global_scale: float): + # for internal use of referencing + cache_k_nope_fp4 = self.k_buffer[layer_id - self.start_layer].view(torch.uint8) + cache_k_nope_fp4_sf = 
self.k_scale_buffer[layer_id - self.start_layer].view( + torch.float8_e4m3fn + ) + + cache_k_nope_fp4_dequant = NVFP4QuantizeUtil.cuda_nvfp4_dequantize( + cache_k_nope_fp4, cache_k_nope_fp4_sf, k_global_scale + ) + return cache_k_nope_fp4_dequant + + def _get_value_buffer(self, layer_id: int, v_global_scale: float): + # for internal use of referencing + cache_v_nope_fp4 = self.v_buffer[layer_id - self.start_layer].view(torch.uint8) + cache_v_nope_fp4_sf = self.v_scale_buffer[layer_id - self.start_layer].view( + torch.float8_e4m3fn + ) + + cache_v_nope_fp4_dequant = NVFP4QuantizeUtil.cuda_nvfp4_dequantize( + cache_v_nope_fp4, cache_v_nope_fp4_sf, v_global_scale + ) + return cache_v_nope_fp4_dequant + + def get_kv_buffer(self, layer_id: int, scale_k: float, scale_v: float): + return self.get_key_buffer(layer_id, scale_k), self.get_value_buffer( + layer_id, scale_v + ) + + # cache_k and cache_v are in bf16 format + def set_kv_buffer( + self, + layer: RadixAttention, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + layer_id_override: Optional[int] = None, + ): + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + if layer_id_override is not None: + layer_id = layer_id_override + else: + layer_id = layer.layer_id + + if is_sm100_supported() or is_sm120_supported(): + cache_k, cache_k_fp4_sf, _ = ( + NVFP4QuantizeUtil.cuda_nvfp4_quantize_blackwell(cache_k, k_scale) + ) + cache_v, cache_v_fp4_sf, _ = ( + NVFP4QuantizeUtil.cuda_nvfp4_quantize_blackwell(cache_v, v_scale) + ) + else: + cache_k, cache_k_fp4_sf, _ = NVFP4QuantizeUtil.batched_quantize( + cache_k, k_scale + ) + cache_v, cache_v_fp4_sf, _ = NVFP4QuantizeUtil.batched_quantize( + cache_v, v_scale + ) + + cache_k = cache_k.view(torch.uint8) + cache_v = cache_v.view(torch.uint8) + + cache_k_fp4_sf = cache_k_fp4_sf.view(torch.uint8) + cache_v_fp4_sf = cache_v_fp4_sf.view(torch.uint8) + + if 
get_is_capture_mode() and self.alt_stream is not None: + # Overlap the copy of K and V cache for small batch size + current_stream = self.device_module.current_stream() + self.alt_stream.wait_stream(current_stream) + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + + self.k_scale_buffer[layer_id - self.start_layer][loc] = cache_k_fp4_sf + with self.device_module.stream(self.alt_stream): + self.v_buffer[layer_id - self.start_layer][loc] = cache_v + + self.v_scale_buffer[layer_id - self.start_layer][loc] = cache_v_fp4_sf + current_stream.wait_stream(self.alt_stream) + else: + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + self.v_buffer[layer_id - self.start_layer][loc] = cache_v + + self.k_scale_buffer[layer_id - self.start_layer][loc] = cache_k_fp4_sf + self.v_scale_buffer[layer_id - self.start_layer][loc] = cache_v_fp4_sf + + def get_dq_kv_buffer( + self, + ): + return (self.dq_k_buffer, self.dq_v_buffer) + + class MHATokenToKVPoolFP4(MHATokenToKVPool): def _create_buffers(self): @@ -1070,6 +1260,9 @@ def _create_buffers(self): ) for _ in range(self.layer_num) ] + from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil + + self.fp4_quant_util = KVFP4QuantizeUtil def _clear_buffers(self): del self.k_buffer @@ -1085,9 +1278,7 @@ def _get_key_buffer(self, layer_id: int): ) cache_k_nope_fp4_sf = self.k_scale_buffer[layer_id - self.start_layer] - from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil - - cache_k_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize( + cache_k_nope_fp4_dequant = self.fp4_quant_util.cuda_nvfp4_dequantize( cache_k_nope_fp4, cache_k_nope_fp4_sf ) return cache_k_nope_fp4_dequant @@ -1101,9 +1292,7 @@ def _get_value_buffer(self, layer_id: int): ) cache_v_nope_fp4_sf = self.v_scale_buffer[layer_id - self.start_layer] - from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil - - cache_v_nope_fp4_dequant = KVFP4QuantizeUtil.batched_dequantize( + 
cache_v_nope_fp4_dequant = self.fp4_quant_util.cuda_nvfp4_dequantize( cache_v_nope_fp4, cache_v_nope_fp4_sf ) return cache_v_nope_fp4_dequant @@ -1131,10 +1320,8 @@ def set_kv_buffer( if v_scale is not None: cache_v.div_(v_scale) - from sglang.srt.layers.quantization.kvfp4_tensor import KVFP4QuantizeUtil - - cache_k, cache_k_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_k) - cache_v, cache_v_fp4_sf = KVFP4QuantizeUtil.batched_quantize(cache_v) + cache_k, cache_k_fp4_sf = self.fp4_quant_util.batched_quantize(cache_k) + cache_v, cache_v_fp4_sf = self.fp4_quant_util.batched_quantize(cache_v) if self.store_dtype != self.dtype: cache_k = cache_k.view(self.store_dtype) @@ -1198,9 +1385,17 @@ def __init__( self.use_mla = use_mla if not use_mla: - TokenToKVPoolClass = MHATokenToKVPool + if is_float4_e2m1fn_x2(dtype): + # TODO(Sam): Add a env flag to choose between NVFP4 and other FP4 + # TokenToKVPoolClass = MHATokenToKVPoolFP4 + TokenToKVPoolClass = MHATokenToKVPoolNVFP4 + else: + TokenToKVPoolClass = MHATokenToKVPool if _is_npu: + assert not is_float4_e2m1fn_x2( + dtype + ), "FP4 is not supported on NPU yet." 
from sglang.srt.hardware_backend.npu.memory_pool_npu import (
                 NPUMHATokenToKVPool,
             )
@@ -1273,18 +1468,37 @@ def _transfer_full_attention_id(self, layer_id: int):
         )
         return self.full_attention_layer_id_mapping[layer_id]
 
-    def get_key_buffer(self, layer_id: int):
+    def get_key_buffer(self, layer_id: int, scale: Optional[float] = None):
         layer_id = self._transfer_full_attention_id(layer_id)
-        return self.full_kv_pool.get_key_buffer(layer_id)
+        return self.full_kv_pool.get_key_buffer(layer_id, scale)
 
-    def get_value_buffer(self, layer_id: int):
+    def get_value_buffer(self, layer_id: int, scale: Optional[float] = None):
         layer_id = self._transfer_full_attention_id(layer_id)
-        return self.full_kv_pool.get_value_buffer(layer_id)
+        return self.full_kv_pool.get_value_buffer(layer_id, scale)
 
-    def get_kv_buffer(self, layer_id: int):
+    def get_kv_buffer(
+        self,
+        layer_id: int,
+    ):
         layer_id = self._transfer_full_attention_id(layer_id)
         return self.full_kv_pool.get_kv_buffer(layer_id)
 
+    def get_fp4_value_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool._get_value_nvfp4_from_nvfp4_buffer(layer_id)
+
+    def get_fp4_key_buffer(self, layer_id: int):
+        layer_id = self._transfer_full_attention_id(layer_id)
+        return self.full_kv_pool._get_key_nvfp4_from_nvfp4_buffer(layer_id)
+
+    def get_dq_kv_buffer(
+        self,
+    ):
+        assert is_float4_e2m1fn_x2(
+            self.dtype
+        ), "get_dq_kv_buffer only available for FP4 KV pool"
+        return self.full_kv_pool.get_dq_kv_buffer()
+
     @contextmanager
     def _transfer_id_context(self, layer: RadixAttention):
diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
index 6cfa91e87c9a..271678282b7b 100644
--- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
+++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
@@ -17,7 +17,7 @@
     HybridLinearKVPool,
     HybridReqToTokenPool,
MHATokenToKVPool, - MHATokenToKVPoolFP4, + MHATokenToKVPoolNVFP4, MLATokenToKVPool, MLATokenToKVPoolFP4, NSATokenToKVPool, @@ -552,7 +552,9 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): ) else: if is_float4_e2m1fn_x2(self.kv_cache_dtype): - self.token_to_kv_pool = MHATokenToKVPoolFP4( + # TODO(Sam): Add a env flag to choose between NVFP4 and other FP4 + # self.token_to_kv_pool = MHATokenToKVPoolFP4( + self.token_to_kv_pool = MHATokenToKVPoolNVFP4( self.max_total_num_tokens, page_size=self.page_size, dtype=self.kv_cache_dtype, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0026eea61a0f..4612b26c6ee0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1802,9 +1802,28 @@ def _handle_attention_backend_compatibility(self): or self.decode_attention_backend == "trtllm_mha" or self.prefill_attention_backend == "trtllm_mha" ): - if not is_sm100_supported(): + # Check prefill backend + prefill_backend = ( + self.prefill_attention_backend + if self.prefill_attention_backend is not None + else self.attention_backend + ) + if prefill_backend == "trtllm_mha" and not is_sm100_supported(): + raise ValueError( + "TRTLLM MHA backend for prefill is only supported on Blackwell GPUs (SM100). Please use a different prefill backend." + ) + + # Check decode backend + decode_backend = ( + self.decode_attention_backend + if self.decode_attention_backend is not None + else self.attention_backend + ) + if decode_backend == "trtllm_mha" and not ( + is_sm90_supported() or is_sm100_supported() or is_sm120_supported() + ): raise ValueError( - "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend." + "TRTLLM MHA backend for decode is only supported on Hopper (SM90), Blackwell (SM100) and (SM120) GPUs. Please use a different decode backend." ) if self.page_size not in [16, 32, 64]: