diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index c3f533b27fb..4e8543213a1 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -625,6 +625,7 @@ def forward_extend(
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -639,11 +640,11 @@
                     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
             else:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
                     cache_loc,
                     k,
-                    v,
+                    k_rope,
                 )
 
         # Use precomputed metadata across all layers
@@ -887,6 +888,7 @@ def forward_decode(
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -901,11 +903,11 @@
                     layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
             else:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,
                     cache_loc,
                     k,
-                    v,
+                    k_rope,
                 )
 
         # Use precomputed metadata across all layers
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 741937a721d..c57ae2aae86 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -92,8 +92,11 @@ def forward(
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
         return forward_batch.attn_backend.forward(
             q,
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 615ec3a2bc1..b28ad55f98d 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -34,6 +34,8 @@
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
 
 
+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -504,6 +572,25 @@ def set_kv_buffer(
         else:
             self.kv_buffer[layer_id][loc] = cache_k
 
+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2c709439b2c..c24c098c029 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -757,14 +757,13 @@ def forward_absorb(
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        k = torch.cat([k_nope, k_pe], dim=-1)
-
         if self.attention_backend == "fa3":
             attn_output = self.attn_mqa(
-                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
 
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
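
Note on the new KV-cache write path (illustration, not part of the patch): set_mla_kv_buffer_triton packs the no-RoPE latent (cache_k_nope) into the first nope_dim lanes of each selected MLA KV-buffer row and the RoPE part (cache_k_rope) into the remaining rope_dim lanes, avoiding the torch.cat that the old path needed. The sketch below is a hypothetical pure-PyTorch reference with made-up shapes (loosely modeled on DeepSeek-V2's kv_lora_rank=512 and qk_rope_head_dim=64); none of it is taken from the PR, but it can serve as a correctness oracle when unit-testing the Triton kernel.

import torch

def set_mla_kv_buffer_reference(kv_buffer, loc, cache_k_nope, cache_k_rope):
    # Hypothetical reference for the layout written by set_mla_kv_buffer_triton:
    # the nope part fills the leading lanes, the rope part the trailing lanes.
    nope_dim = cache_k_nope.shape[-1]
    kv_buffer[loc, :, :nope_dim] = cache_k_nope
    kv_buffer[loc, :, nope_dim:] = cache_k_rope

# Assumed shapes for illustration only: 1024 cache slots, one latent head,
# 512 nope lanes + 64 rope lanes, three tokens being written.
kv_buffer = torch.zeros(1024, 1, 512 + 64, dtype=torch.bfloat16)
loc = torch.tensor([3, 7, 42])
cache_k_nope = torch.randn(3, 1, 512, dtype=torch.bfloat16)
cache_k_rope = torch.randn(3, 1, 64, dtype=torch.bfloat16)
set_mla_kv_buffer_reference(kv_buffer, loc, cache_k_nope, cache_k_rope)

One observation on the kernel itself: a block whose range extends past nope_dim is sourced entirely from cache_k_rope, so the kernel implicitly assumes no block straddles the nope/rope boundary, i.e. nope_dim is a multiple of BLOCK. That holds for the default BLOCK of 128 with kv_lora_rank=512.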