diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
index 2e7c67758691..e8aaa00114a7 100644
--- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -219,6 +219,7 @@ def forward_decode(
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MHA kernel."""
         cache_loc = forward_batch.out_cache_loc
@@ -230,7 +231,7 @@ def forward_decode(
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
-        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
         ).permute(0, 2, 1, 3)
@@ -246,6 +247,8 @@ def forward_decode(
             if getattr(layer, "k_scale_float", None) is not None
             else 1.0
         )
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
         bmm1_scale = q_scale * k_scale * layer.scaling
         bmm2_scale = 1.0
 
@@ -262,6 +265,7 @@ def forward_decode(
             bmm2_scale=bmm2_scale,
             window_left=self.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -274,6 +278,7 @@ def forward_extend(
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        **kwargs,
     ):
         cache_loc = forward_batch.out_cache_loc
         if save_kv_cache and k is not None:
@@ -281,6 +286,7 @@ def forward_extend(
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
@@ -292,6 +298,8 @@ def forward_extend(
 
         # TODO: bmm1_scale and bmm2_scale might require modification
         # TODO: Change once quantization is supported
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -316,6 +324,7 @@ def forward_extend(
             cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
             window_left=self.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index cf40c652bed8..d442949eb8c1 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -241,7 +241,7 @@ def __init__(
         )
 
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
         )
 
         self.o_proj = RowParallelLinear(
@@ -804,7 +804,9 @@ def load_weights(
                 if "sinks" in name:
                     start = tp_rank * param.numel()
                     param.data.copy_(
-                        loaded_weight[start : start + param.numel()]
+                        loaded_weight[start : start + param.numel()].to(
+                            torch.float32
+                        )
                     )
                 else:
                     weight_loader = getattr(
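For reference, here is a minimal sketch (not part of the patch) of what the `sinks` values do once the backend forwards them to the TRTLLM kernel: each head carries one extra learned logit that joins the softmax normalization but contributes no value, which is why the added comment describes it as an "additional value per head in the denominator of the softmax". The function name, shapes, and eager-mode math below are illustrative assumptions, not SGLang or FlashInfer APIs.

```python
# Illustrative only: a naive PyTorch reference for attention with per-head sinks.
# The helper name and shapes are assumptions; the PR itself just plumbs `sinks`
# through to the TRTLLM decode/extend kernels.
import torch


def attention_with_sinks(q, k, v, sinks, scale):
    # q: [num_heads, q_len, head_dim]; k, v: [num_heads, kv_len, head_dim]
    # sinks: [num_heads] learned logits, kept in float32 as in the gpt_oss change
    logits = torch.einsum("hqd,hkd->hqk", q.float(), k.float()) * scale
    sink_col = sinks.float().view(-1, 1, 1).expand(-1, logits.shape[1], 1)
    # The sink column participates in normalization but is dropped afterwards,
    # so its only effect is to shrink the real attention weights.
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)[..., :-1]
    return torch.einsum("hqk,hkd->hqd", probs, v.float()).to(v.dtype)
```

Keeping `self.sinks` in float32 matters here because the sink logit is compared against float32 softmax statistics inside the kernel; loading it in a lower-precision `params_dtype` would change the normalization slightly.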