From 8e942ab2d4f3f7b23c19c8acb34ebd61818e41fa Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Thu, 18 Dec 2025 06:51:47 -0800 Subject: [PATCH 1/3] clean up trtllm mha backend and use NHD kv layout --- .../layers/attention/trtllm_mha_backend.py | 187 ++++++++---------- 1 file changed, 82 insertions(+), 105 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index ddd3a67eadc9..93524515ec22 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -538,36 +538,6 @@ def get_cuda_graph_seq_len_fill_value(self) -> int: """Get the fill value for sequence lengths in CUDA graph.""" return 1 - def _should_use_fused_fp8_path(self, save_kv_cache: bool, k: torch.Tensor) -> bool: - """Check if we should use the fused FP8 KV cache write path.""" - return save_kv_cache and k is not None and self.data_type == torch.float8_e4m3fn - - def _fused_fp8_set_kv_buffer( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer: RadixAttention, - forward_batch: ForwardBatch, - **kwargs, - ): - """Fused FP8 quantization and KV cache write.""" - cache_loc = self._get_layer_cache_loc(layer, forward_batch.out_cache_loc) - - # Get K/V cache buffers from token_to_kv_pool - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - - fused_fp8_set_kv_buffer( - k=k, - v=v, - k_cache=k_cache, - v_cache=v_cache, - cache_loc=cache_loc, - k_scale=layer.k_scale, # May be None - v_scale=layer.v_scale, # May be None - page_size=self.page_size, - ) - def init_forward_metadata(self, forward_batch: ForwardBatch): """Initialize the metadata for a forward pass.""" @@ -692,24 +662,40 @@ def forward_decode( **kwargs, ) -> torch.Tensor: """Run forward for decode using TRTLLM MHA kernel.""" - cache_loc = forward_batch.out_cache_loc - - use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k) - - if use_fused_fp8_path: - # Use fused FP8 quantization + KV cache write path - self._fused_fp8_set_kv_buffer( - q=q, - k=k, - v=v, - layer=layer, - forward_batch=forward_batch, - ) - k = None - v = None - else: - # Use original set_kv_buffer path - if save_kv_cache and k is not None: + cache_loc = self._get_layer_cache_loc(layer, forward_batch.out_cache_loc) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + q_scale_float = 1.0 + k_scale_float = ( + layer.k_scale_float + if getattr(layer, "k_scale_float", None) is not None + else 1.0 + ) + v_scale_float = ( + layer.v_scale_float + if getattr(layer, "v_scale_float", None) is not None + else 1.0 + ) + + if save_kv_cache and k is not None: + if kwargs.get("cos_sin_cache") is not None: + raise NotImplementedError( + "Fused RoPE + FP8 quantization + KV cache update is not implemented yet" + ) + elif self.data_type == torch.float8_e4m3fn: + # Use fused FP8 quantization + KV cache update path + fused_fp8_set_kv_buffer( + k=k, + v=v, + k_cache=k_cache, + v_cache=v_cache, + cache_loc=cache_loc, + k_scale=layer.k_scale, + v_scale=layer.v_scale, + page_size=self.page_size, + ) + else: + # Use original set_kv_buffer path forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) @@ -717,32 +703,17 @@ def forward_decode( if self.data_type == torch.float8_e4m3fn: q = q.to(torch.float8_e4m3fn) q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - k_cache, 
v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - # shape conversion: - # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim] - k_cache = k_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.head_dim - ).permute(0, 2, 1, 3) - v_cache = v_cache.view( - -1, self.page_size, layer.tp_v_head_num, layer.head_dim - ).permute(0, 2, 1, 3) + k_cache = k_cache.view(-1, self.page_size, layer.tp_k_head_num, layer.head_dim) + v_cache = v_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.head_dim) if layer.tp_k_head_num == 1: k_cache = canonicalize_stride(k_cache) if layer.tp_v_head_num == 1: v_cache = canonicalize_stride(v_cache) - kv_cache = (k_cache, v_cache) - # TODO: add support for quantization - q_scale = 1.0 - k_scale = ( - layer.k_scale_float - if getattr(layer, "k_scale_float", None) is not None - else 1.0 - ) - bmm1_scale = q_scale * k_scale * layer.scaling - bmm2_scale = 1.0 + bmm1_scale = q_scale_float * k_scale_float * layer.scaling + bmm2_scale = v_scale_float # sink: additional value per head in the denominator of the softmax. attention_sink = kwargs.get("sinks", None) @@ -762,6 +733,7 @@ def forward_decode( window_left=layer.sliding_window_size, sinks=attention_sink, out_dtype=self.q_data_type, # model_runner.dtype + kv_layout="NHD", ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -773,27 +745,44 @@ def forward_extend( v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, - save_kv_cache=True, + save_kv_cache: bool = True, **kwargs, - ): - cache_loc = forward_batch.out_cache_loc - - use_fused_fp8_path = self._should_use_fused_fp8_path(save_kv_cache, k) - - if use_fused_fp8_path: - # Use fused FP8 quantization + KV cache write path - self._fused_fp8_set_kv_buffer( - q=q, - k=k, - v=v, - layer=layer, - forward_batch=forward_batch, - ) - k = None - v = None - else: - # Use original set_kv_buffer path - if save_kv_cache and k is not None: + ) -> torch.Tensor: + """Run forward for extend using TRTLLM MHA kernel.""" + cache_loc = self._get_layer_cache_loc(layer, forward_batch.out_cache_loc) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + q_scale_float = 1.0 + k_scale_float = ( + layer.k_scale_float + if getattr(layer, "k_scale_float", None) is not None + else 1.0 + ) + v_scale_float = ( + layer.v_scale_float + if getattr(layer, "v_scale_float", None) is not None + else 1.0 + ) + + if save_kv_cache and k is not None: + if kwargs.get("cos_sin_cache") is not None: + raise NotImplementedError( + "Fused RoPE + FP8 quantization + KV cache update is not implemented yet" + ) + elif self.data_type == torch.float8_e4m3fn: + # Use fused FP8 quantization + KV cache update path + fused_fp8_set_kv_buffer( + k=k, + v=v, + k_cache=k_cache, + v_cache=v_cache, + cache_loc=cache_loc, + k_scale=layer.k_scale, + v_scale=layer.v_scale, + page_size=self.page_size, + ) + else: + # Use original set_kv_buffer path forward_batch.token_to_kv_pool.set_kv_buffer( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) @@ -801,33 +790,19 @@ def forward_extend( if self.data_type == torch.float8_e4m3fn: q = q.to(torch.float8_e4m3fn) q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim] - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - k_cache = k_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.head_dim - ).permute(0, 2, 
1, 3) - v_cache = v_cache.view( - -1, self.page_size, layer.tp_v_head_num, layer.head_dim - ).permute(0, 2, 1, 3) + k_cache = k_cache.view(-1, self.page_size, layer.tp_k_head_num, layer.head_dim) + v_cache = v_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.head_dim) if layer.tp_k_head_num == 1: k_cache = canonicalize_stride(k_cache) if layer.tp_v_head_num == 1: v_cache = canonicalize_stride(v_cache) - kv_cache = (k_cache, v_cache) + bmm1_scale = q_scale_float * k_scale_float * layer.scaling + bmm2_scale = v_scale_float # sink: additional value per head in the denominator of the softmax. attention_sink = kwargs.get("sinks", None) - # TODO: add support for quantization - q_scale = 1.0 - k_scale = ( - layer.k_scale_float - if getattr(layer, "k_scale_float", None) is not None - else 1.0 - ) - bmm1_scale = q_scale * k_scale * layer.scaling - bmm2_scale = 1.0 page_table = self._get_layer_page_table(layer, forward_batch) @@ -845,6 +820,7 @@ def forward_extend( sinks=attention_sink, out_dtype=self.q_data_type, # model_runner.dtype q_len_per_req=self.forward_metadata.max_seq_len_q, + kv_layout="NHD", ) else: o = flashinfer.prefill.trtllm_batch_context_with_kv_cache( @@ -863,6 +839,7 @@ def forward_extend( window_left=layer.sliding_window_size, sinks=attention_sink, out_dtype=self.q_data_type, # model_runner.dtype + kv_layout="NHD", ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) From 91e1a1b511bff1ea21288ecdc7d492b9a79ef673 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:15:22 -0800 Subject: [PATCH 2/3] support Flashinfer rope_quantize_fp8 + append_paged_kv_cache for GPT-OSS --- .../srt/layers/attention/base_attn_backend.py | 4 + .../layers/attention/trtllm_mha_backend.py | 238 +++++++++++++++++- python/sglang/srt/models/gpt_oss.py | 13 +- 3 files changed, 246 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index 8d14e32a916b..7e1bd5b55527 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -167,3 +167,7 @@ def get_indexer_metadata( ) -> Optional[BaseIndexerMetadata]: """Get the indexer metadata. 
None means don't support indexer.""" return None + + def support_rope_fusion(self) -> bool: + """Check if the current backend supports RoPE fusion.""" + return False diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 93524515ec22..eb31a60f7468 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -19,7 +19,10 @@ from sglang.srt.layers.attention.triton_ops.trtllm_fp8_kv_kernel import ( fused_fp8_set_kv_buffer, ) -from sglang.srt.layers.attention.utils import canonicalize_stride +from sglang.srt.layers.attention.utils import ( + canonicalize_stride, + create_flashinfer_kv_indices_triton, +) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available @@ -60,6 +63,18 @@ class TRTLLMMHAMetadata: # Page table for SWA layers (translated from full pool indices to SWA pool indices) swa_page_table: torch.Tensor = None + # The following fields are used for Flashinfer KV update + # kv_indices: page indices for all requests + kv_indices: torch.Tensor = None + # kv_indptr: cumulative page count per request + kv_indptr: torch.Tensor = None + # batch_indices: which request each token belongs to + batch_indices: torch.Tensor = None + # positions: position of each token within its sequence + positions: torch.Tensor = None + # kv_last_page_len: number of valid entries in last page per request + kv_last_page_len: torch.Tensor = None + class TRTLLMHAAttnBackend(FlashInferAttnBackend): """TRTLLM MHA attention kernel from flashinfer.""" @@ -222,6 +237,11 @@ def init_cuda_graph_state( ), } + if self.support_rope_fusion(): + self.decode_cuda_graph_metadata["kv_indices"] = torch.zeros( + max_bs * max_num_pages, dtype=torch.int32, device=self.device + ) + if ( self.speculative_num_draft_tokens is not None and self.speculative_num_draft_tokens > 0 @@ -292,6 +312,14 @@ def init_cuda_graph_state( ), } + if self.support_rope_fusion(): + self.target_verify_metadata["kv_indices"] = torch.zeros( + max_bs * max_num_pages, dtype=torch.int32, device=self.device + ) + self.draft_extend_metadata["kv_indices"] = torch.zeros( + max_bs * max_num_pages, dtype=torch.int32, device=self.device + ) + def init_forward_metadata_capture_cuda_graph( self, bs: int, @@ -427,6 +455,11 @@ def init_forward_metadata_capture_cuda_graph( ) self.draft_extend_metadata[bs] = metadata + + if self.support_rope_fusion(): + metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"] + self._init_forward_metadata_for_rope_fusion(metadata, bs, num_tokens) + self.forward_metadata = metadata def init_forward_metadata_replay_cuda_graph( @@ -532,6 +565,24 @@ def init_forward_metadata_replay_cuda_graph( ] metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size) self._copy_swa_page_table(metadata, page_indices, max_seq_pages) + + if self.support_rope_fusion(): + # Compute total number of tokens + if forward_mode.is_decode_or_idle(): + if spec_info is not None: + total_num_tokens = bs * spec_info.num_draft_tokens + else: + total_num_tokens = bs + elif forward_mode.is_target_verify(): + total_num_tokens = bs * self.speculative_num_draft_tokens + else: # draft_extend + total_num_tokens = metadata.cu_seqlens_q[-1].item() + + metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"] + 
self._init_forward_metadata_for_rope_fusion( + metadata, bs, total_num_tokens, update_inplace=True + ) + self.forward_metadata = metadata def get_cuda_graph_seq_len_fill_value(self) -> int: @@ -605,7 +656,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] - else: metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() @@ -649,8 +699,112 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.swa_page_table[:, self.strided_indices] // self.page_size ) + if self.support_rope_fusion(): + # Compute total number of tokens + if forward_batch.forward_mode.is_decode_or_idle(): + if forward_batch.spec_info is not None: + total_num_tokens = ( + batch_size * forward_batch.spec_info.num_draft_tokens + ) + else: + total_num_tokens = batch_size + elif forward_batch.forward_mode.is_target_verify(): + total_num_tokens = batch_size * self.speculative_num_draft_tokens + else: + total_num_tokens = forward_batch.extend_num_tokens + + self._init_forward_metadata_for_rope_fusion( + metadata, batch_size, total_num_tokens + ) + self.forward_metadata = metadata + def _init_forward_metadata_for_rope_fusion( + self, + metadata: TRTLLMMHAMetadata, + batch_size: int, + total_num_tokens: int, + update_inplace: bool = False, + ): + """ + Initialize the following metadata for RoPE + FP8 quantization + KV cache update kernel. + - kv_indptr: cumulative page count per request + - batch_indices: which request each token belongs to + - positions: position of each token within its sequence + - kv_last_page_len: number of valid entries in last page per request + - kv_indices: page indices for all requests + """ + + # Compute number of pages per request + num_pages_per_req = ( + metadata.cache_seqlens_int32 + self.page_size - 1 + ) // self.page_size + + # kv_indptr + if update_inplace: + metadata.kv_indptr[1:].copy_( + torch.cumsum(num_pages_per_req, dim=0, dtype=torch.int32) + ) + else: + metadata.kv_indptr = torch.nn.functional.pad( + torch.cumsum(num_pages_per_req, dim=0, dtype=torch.int32), (1, 0) + ) + + # batch_indices and positions + if update_inplace: + flashinfer.page.get_batch_indices_positions( + metadata.cu_seqlens_q, + metadata.cache_seqlens_int32, + total_num_tokens, + metadata.batch_indices, + metadata.positions, + ) + else: + metadata.batch_indices, metadata.positions = ( + flashinfer.page.get_batch_indices_positions( + metadata.cu_seqlens_q, + metadata.cache_seqlens_int32, + total_num_tokens, + ) + ) + + # kv_last_page_len + if update_inplace: + metadata.kv_last_page_len.copy_( + metadata.cache_seqlens_int32 - (num_pages_per_req - 1) * self.page_size + ) + else: + metadata.kv_last_page_len = ( + metadata.cache_seqlens_int32 - (num_pages_per_req - 1) * self.page_size + ) + + # kv_indices + device = metadata.kv_indptr.device + if metadata.kv_indices is not None: + # decode + assert ( + metadata.kv_indices.size(0) >= metadata.kv_indptr[-1] + ), f"kv_indices.size(0)={metadata.kv_indices.size(0)} < kv_indptr[-1]={metadata.kv_indptr[-1]}" + metadata.kv_indices = metadata.kv_indices[: metadata.kv_indptr[-1]] + else: + # extend + metadata.kv_indices = torch.zeros( + metadata.kv_indptr[-1], dtype=torch.int32, device=device + ) + create_flashinfer_kv_indices_triton[(batch_size,)]( + metadata.page_table, + torch.arange(batch_size, dtype=torch.int32, device=device), + num_pages_per_req, + 
metadata.kv_indptr, + None, + metadata.kv_indices, + metadata.page_table.stride(0), + ) + + def support_rope_fusion(self) -> bool: + """Check if supports RoPE fusion.""" + return self.data_type == torch.float8_e4m3fn + def forward_decode( self, q: torch.Tensor, @@ -678,9 +832,43 @@ def forward_decode( ) if save_kv_cache and k is not None: - if kwargs.get("cos_sin_cache") is not None: - raise NotImplementedError( - "Fused RoPE + FP8 quantization + KV cache update is not implemented yet" + if self.support_rope_fusion() and kwargs.get("rotary_emb") is not None: + # Use fused RoPE + FP8 quantization + KV cache update path + rotary_emb = kwargs.get("rotary_emb") + + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + k_cache = k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + v_cache = v_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + q, k = flashinfer.rope.rope_quantize_fp8( + q_rope=q, + k_rope=k, + q_nope=None, + k_nope=None, + cos_sin_cache=rotary_emb.cos_sin_cache, + pos_ids=forward_batch.positions, + is_neox=rotary_emb.is_neox_style, + quantize_dtype=torch.float8_e4m3fn, + quant_scale_q=q_scale_float, + quant_scale_kv=k_scale_float, + )[:2] + + v = v.div_(v_scale_float).to(torch.float8_e4m3fn) + + flashinfer.page.append_paged_kv_cache( + append_key=k, + append_value=v, + batch_indices=self.forward_metadata.batch_indices, + positions=self.forward_metadata.positions, + paged_kv_cache=(k_cache, v_cache), + kv_indices=self.forward_metadata.kv_indices, + kv_indptr=self.forward_metadata.kv_indptr, + kv_last_page_len=self.forward_metadata.kv_last_page_len, + kv_layout="NHD", ) elif self.data_type == torch.float8_e4m3fn: # Use fused FP8 quantization + KV cache update path @@ -765,9 +953,43 @@ def forward_extend( ) if save_kv_cache and k is not None: - if kwargs.get("cos_sin_cache") is not None: - raise NotImplementedError( - "Fused RoPE + FP8 quantization + KV cache update is not implemented yet" + if self.support_rope_fusion() and kwargs.get("rotary_emb") is not None: + # Use fused RoPE + FP8 quantization + KV cache update path + rotary_emb = kwargs.get("rotary_emb") + + q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + k_cache = k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.head_dim + ) + v_cache = v_cache.view( + -1, self.page_size, layer.tp_v_head_num, layer.head_dim + ) + + q, k = flashinfer.rope.rope_quantize_fp8( + q_rope=q, + k_rope=k, + q_nope=None, + k_nope=None, + cos_sin_cache=rotary_emb.cos_sin_cache, + pos_ids=forward_batch.positions, + is_neox=rotary_emb.is_neox_style, + quantize_dtype=torch.float8_e4m3fn, + quant_scale_q=q_scale_float, + quant_scale_kv=k_scale_float, + )[:2] + + v = v.div_(v_scale_float).to(torch.float8_e4m3fn) + + flashinfer.page.append_paged_kv_cache( + append_key=k, + append_value=v, + batch_indices=self.forward_metadata.batch_indices, + positions=self.forward_metadata.positions, + paged_kv_cache=(k_cache, v_cache), + kv_indices=self.forward_metadata.kv_indices, + kv_indptr=self.forward_metadata.kv_indptr, + kv_last_page_len=self.forward_metadata.kv_last_page_len, + kv_layout="NHD", ) elif self.data_type == torch.float8_e4m3fn: # Use fused FP8 quantization + KV cache update path diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 96caaa65b57c..941d6836b58c 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -332,7 +332,9 @@ def forward_prepare( else None ), } - q, k = 
self.rotary_emb(positions, q, k, **extra_args) + # Skip RoPE if current attn backend supports RoPE fusion + if not forward_batch.attn_backend.support_rope_fusion(): + q, k = self.rotary_emb(positions, q, k, **extra_args) inner_state = q, k, v, forward_batch return None, forward_batch, inner_state @@ -340,10 +342,19 @@ def forward_core(self, intermediate_state): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: return hidden_states + + # Pass RoPE into attn backend if it supports RoPE fusion + extra_args = {} + if forward_batch.attn_backend.support_rope_fusion(): + extra_args = { + "rotary_emb": self.rotary_emb, + } + attn_output = self.attn( *inner_state, sinks=self.sinks, save_kv_cache=not enable_fused_set_kv_buffer(forward_batch), + **extra_args, ) output, _ = self.o_proj(attn_output) return output From e59267fb52cdde2e8b52f296b3481ff7376d41c3 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 22 Dec 2025 08:49:54 -0800 Subject: [PATCH 3/3] support Flashinfer rope_quantize_fp8_append_paged_kv_cache --- .../layers/attention/trtllm_mha_backend.py | 92 ++++++++----------- 1 file changed, 36 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index eb31a60f7468..8339fe302279 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -72,8 +72,6 @@ class TRTLLMMHAMetadata: batch_indices: torch.Tensor = None # positions: position of each token within its sequence positions: torch.Tensor = None - # kv_last_page_len: number of valid entries in last page per request - kv_last_page_len: torch.Tensor = None class TRTLLMHAAttnBackend(FlashInferAttnBackend): @@ -457,7 +455,12 @@ def init_forward_metadata_capture_cuda_graph( self.draft_extend_metadata[bs] = metadata if self.support_rope_fusion(): - metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"] + if forward_mode.is_decode_or_idle(): + metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"] + elif forward_mode.is_target_verify(): + metadata.kv_indices = self.target_verify_metadata["kv_indices"] + else: + metadata.kv_indices = self.draft_extend_metadata["kv_indices"] self._init_forward_metadata_for_rope_fusion(metadata, bs, num_tokens) self.forward_metadata = metadata @@ -573,12 +576,14 @@ def init_forward_metadata_replay_cuda_graph( total_num_tokens = bs * spec_info.num_draft_tokens else: total_num_tokens = bs + metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"] elif forward_mode.is_target_verify(): total_num_tokens = bs * self.speculative_num_draft_tokens + metadata.kv_indices = self.target_verify_metadata["kv_indices"] else: # draft_extend total_num_tokens = metadata.cu_seqlens_q[-1].item() + metadata.kv_indices = self.draft_extend_metadata["kv_indices"] - metadata.kv_indices = self.decode_cuda_graph_metadata["kv_indices"] self._init_forward_metadata_for_rope_fusion( metadata, bs, total_num_tokens, update_inplace=True ) @@ -731,7 +736,6 @@ def _init_forward_metadata_for_rope_fusion( - kv_indptr: cumulative page count per request - batch_indices: which request each token belongs to - positions: position of each token within its sequence - - kv_last_page_len: number of valid entries in last page per request - kv_indices: page indices for all requests """ @@ -768,28 +772,16 @@ def _init_forward_metadata_for_rope_fusion( ) ) - # 
kv_last_page_len - if update_inplace: - metadata.kv_last_page_len.copy_( - metadata.cache_seqlens_int32 - (num_pages_per_req - 1) * self.page_size - ) - else: - metadata.kv_last_page_len = ( - metadata.cache_seqlens_int32 - (num_pages_per_req - 1) * self.page_size - ) - # kv_indices device = metadata.kv_indptr.device - if metadata.kv_indices is not None: - # decode - assert ( - metadata.kv_indices.size(0) >= metadata.kv_indptr[-1] - ), f"kv_indices.size(0)={metadata.kv_indices.size(0)} < kv_indptr[-1]={metadata.kv_indptr[-1]}" - metadata.kv_indices = metadata.kv_indices[: metadata.kv_indptr[-1]] - else: - # extend + if metadata.kv_indices is None: + # The max number of pages is kv_indptr[-1] + # Use upper bound to avoid D2H transfer from accessing kv_indptr[-1] + max_pages_per_req = ( + metadata.max_seq_len_k + self.page_size - 1 + ) // self.page_size metadata.kv_indices = torch.zeros( - metadata.kv_indptr[-1], dtype=torch.int32, device=device + batch_size * max_pages_per_req, dtype=torch.int32, device=device ) create_flashinfer_kv_indices_triton[(batch_size,)]( metadata.page_table, @@ -836,7 +828,7 @@ def forward_decode( # Use fused RoPE + FP8 quantization + KV cache update path rotary_emb = kwargs.get("rotary_emb") - q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q = q.view(-1, layer.tp_q_head_num, layer.head_dim) k_cache = k_cache.view( -1, self.page_size, layer.tp_k_head_num, layer.head_dim ) @@ -844,32 +836,26 @@ def forward_decode( -1, self.page_size, layer.tp_v_head_num, layer.head_dim ) - q, k = flashinfer.rope.rope_quantize_fp8( + q = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( q_rope=q, k_rope=k, q_nope=None, k_nope=None, + v=v, cos_sin_cache=rotary_emb.cos_sin_cache, pos_ids=forward_batch.positions, + paged_kv_cache=(k_cache, v_cache), + kv_indices=self.forward_metadata.kv_indices, + kv_indptr=self.forward_metadata.kv_indptr, + batch_indices=self.forward_metadata.batch_indices, + positions=self.forward_metadata.positions, is_neox=rotary_emb.is_neox_style, quantize_dtype=torch.float8_e4m3fn, quant_scale_q=q_scale_float, quant_scale_kv=k_scale_float, - )[:2] - - v = v.div_(v_scale_float).to(torch.float8_e4m3fn) - - flashinfer.page.append_paged_kv_cache( - append_key=k, - append_value=v, - batch_indices=self.forward_metadata.batch_indices, - positions=self.forward_metadata.positions, - paged_kv_cache=(k_cache, v_cache), - kv_indices=self.forward_metadata.kv_indices, - kv_indptr=self.forward_metadata.kv_indptr, - kv_last_page_len=self.forward_metadata.kv_last_page_len, + page_size=self.page_size, kv_layout="NHD", - ) + )[0] elif self.data_type == torch.float8_e4m3fn: # Use fused FP8 quantization + KV cache update path fused_fp8_set_kv_buffer( @@ -957,7 +943,7 @@ def forward_extend( # Use fused RoPE + FP8 quantization + KV cache update path rotary_emb = kwargs.get("rotary_emb") - q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q = q.view(-1, layer.tp_q_head_num, layer.head_dim) k_cache = k_cache.view( -1, self.page_size, layer.tp_k_head_num, layer.head_dim ) @@ -965,32 +951,26 @@ def forward_extend( -1, self.page_size, layer.tp_v_head_num, layer.head_dim ) - q, k = flashinfer.rope.rope_quantize_fp8( + q = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( q_rope=q, k_rope=k, q_nope=None, k_nope=None, + v=v, cos_sin_cache=rotary_emb.cos_sin_cache, pos_ids=forward_batch.positions, + paged_kv_cache=(k_cache, v_cache), + kv_indices=self.forward_metadata.kv_indices, + kv_indptr=self.forward_metadata.kv_indptr, + 
batch_indices=self.forward_metadata.batch_indices, + positions=self.forward_metadata.positions, is_neox=rotary_emb.is_neox_style, quantize_dtype=torch.float8_e4m3fn, quant_scale_q=q_scale_float, quant_scale_kv=k_scale_float, - )[:2] - - v = v.div_(v_scale_float).to(torch.float8_e4m3fn) - - flashinfer.page.append_paged_kv_cache( - append_key=k, - append_value=v, - batch_indices=self.forward_metadata.batch_indices, - positions=self.forward_metadata.positions, - paged_kv_cache=(k_cache, v_cache), - kv_indices=self.forward_metadata.kv_indices, - kv_indptr=self.forward_metadata.kv_indptr, - kv_last_page_len=self.forward_metadata.kv_last_page_len, + page_size=self.page_size, kv_layout="NHD", - ) + )[0] elif self.data_type == torch.float8_e4m3fn: # Use fused FP8 quantization + KV cache update path fused_fp8_set_kv_buffer(
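
Editor's note on PATCH 1/3 (illustrative, not part of the series): the kernels now consume the KV pool's native NHD layout ([num_pages, page_size, num_kv_heads, head_dim]) via kv_layout="NHD", instead of permuting each page to HND first. Below is a minimal PyTorch sketch of the difference, assuming the pool stores tokens flat as [num_tokens, num_kv_heads, head_dim]: the NHD view is zero-copy and contiguous, while the old HND permute produced a strided, non-contiguous tensor (the remaining canonicalize_stride calls appear to guard only the degenerate single-KV-head case, where a size-1 dimension leaves the view's strides ambiguous).

import torch

num_pages, page_size, num_kv_heads, head_dim = 8, 16, 4, 64
pool = torch.randn(num_pages * page_size, num_kv_heads, head_dim)

# New path (PATCH 1/3): zero-copy NHD view, passed with kv_layout="NHD".
nhd = pool.view(num_pages, page_size, num_kv_heads, head_dim)
# Old path: permute to HND [num_pages, num_kv_heads, page_size, head_dim].
hnd = nhd.permute(0, 2, 1, 3)

assert nhd.is_contiguous() and not hnd.is_contiguous()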
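
Editor's note on the scale handling (illustrative, not part of the series): both forward paths now fold the FP8 dequantization factors into the kernel's two GEMM scales, bmm1_scale = q_scale * k_scale * layer.scaling and bmm2_scale = v_scale, resolving the old "TODO: add support for quantization" that hard-coded bmm2_scale = 1.0. A worked check of this folding, assuming layer.scaling is the usual 1/sqrt(head_dim) and ignoring FP8 rounding error (the cache stores k/k_scale and v/v_scale):

import math
import torch

head_dim, q_scale, k_scale, v_scale = 64, 1.0, 0.07, 0.05
q, k, v = (torch.randn(8, head_dim) for _ in range(3))
k_q, v_q = k / k_scale, v / v_scale  # what the FP8 cache would store

# bmm1_scale dequantizes the QK^T logits:
logits_ref = (q @ k.T) / math.sqrt(head_dim)
logits_fp8 = (q @ k_q.T) * (q_scale * k_scale / math.sqrt(head_dim))
torch.testing.assert_close(logits_ref, logits_fp8)

# bmm2_scale dequantizes the PV product:
o_ref = torch.softmax(logits_ref, -1) @ v
o_fp8 = (torch.softmax(logits_fp8, -1) @ v_q) * v_scale
torch.testing.assert_close(o_ref, o_fp8)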
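
Editor's note on PATCH 2/3 and PATCH 3/3 (illustrative, not part of the series): _init_forward_metadata_for_rope_fusion builds the tensors that FlashInfer's paged-KV append (and, in PATCH 3/3, the fused rope_quantize_fp8_append_paged_kv_cache) consumes. The sketch below reconstructs them in plain PyTorch under simplified assumptions (a dense [batch, max_pages] page table, no Triton); the real code fills kv_indices with create_flashinfer_kv_indices_triton and, in PATCH 3/3, sizes it by a per-request upper bound so reading kv_indptr[-1] never forces a device-to-host sync.

import torch

def build_append_metadata(page_table, seq_lens, q_lens, page_size):
    # page_table: [batch, max_pages] physical page ids per request
    # seq_lens:   [batch] total KV length per request after this step
    # q_lens:     [batch] number of new tokens per request in this step
    num_pages_per_req = (seq_lens + page_size - 1) // page_size

    # kv_indptr: exclusive prefix sum of page counts, shape [batch + 1]
    kv_indptr = torch.nn.functional.pad(
        torch.cumsum(num_pages_per_req, dim=0, dtype=torch.int32), (1, 0)
    )
    # kv_indices: pages of request i sit at kv_indices[kv_indptr[i]:kv_indptr[i+1]]
    kv_indices = torch.cat(
        [page_table[i, : int(n)] for i, n in enumerate(num_pages_per_req)]
    ).to(torch.int32)
    # batch_indices / positions: one entry per *new* token, telling the append
    # kernel which request and which in-sequence slot the token targets
    batch_indices = torch.repeat_interleave(
        torch.arange(len(q_lens), dtype=torch.int32), q_lens
    )
    positions = torch.cat(
        [
            torch.arange(s - n, s, dtype=torch.int32)
            for s, n in zip(seq_lens.tolist(), q_lens.tolist())
        ]
    )
    return kv_indptr, kv_indices, batch_indices, positions

# Two decode requests (KV lengths 5 and 3, one new token each), page_size 4:
meta = build_append_metadata(
    torch.tensor([[7, 2, 0], [4, 0, 0]]), torch.tensor([5, 3]), torch.tensor([1, 1]), 4
)
# kv_indptr=[0, 2, 3], kv_indices=[7, 2, 4], batch_indices=[0, 1], positions=[4, 2]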