diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 0c8f832ed26b..7afa0a09ed7c 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -802,6 +802,7 @@ def forward_decode(
         k_rope: Optional[torch.Tensor] = None,
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
+        llama_4_scaling: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
         merge_query = q_rope is not None
@@ -843,6 +844,11 @@ def forward_decode(
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
+        # Apply llama 4 scaling if provided
+        if llama_4_scaling is not None:
+            query = query.to(self.q_data_type) * llama_4_scaling
+            query = query.to(self.data_type)
+
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
         if query.dim() == 3:
             query = query.unsqueeze(1)
@@ -903,6 +909,7 @@ def forward_extend(
         k_rope: Optional[torch.Tensor] = None,
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
+        llama_4_scaling: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
         if (
@@ -955,6 +962,10 @@ def forward_extend(
 
         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
+        # Apply llama 4 scaling if provided
+        if llama_4_scaling is not None:
+            q *= llama_4_scaling
+
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend(include_v2=True)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 9dc9fed76524..1331fe916724 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1158,6 +1158,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def _get_llama_4_scaling(
+    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
+) -> torch.Tensor:
+    scaling = 1 + scaling_beta * torch.log(
+        1 + torch.floor(positions / original_max_position_embeddings)
+    )
+    # Broadcast over num_heads and head_dim
+    return scaling[..., None, None]
+
+
 class DeepseekV2AttentionMLA(nn.Module):
 
     def __init__(
@@ -1212,12 +1222,6 @@ def __init__(
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"
 
-        self.llama_4_scaling = (
-            config.to_dict()["llama_4_scaling"]["beta"]
-            if "llama_4_scaling" in config
-            else None
-        )
-
         # For tensor parallel attention
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -1470,12 +1474,14 @@ def forward(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
+        llama_4_scaling: Optional[torch.Tensor],
     ):
         s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
+           llama_4_scaling=llama_4_scaling,
         )
         return self.forward_core(s)
 
@@ -1485,6 +1491,7 @@ def forward_prepare(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
+        llama_4_scaling: Optional[torch.Tensor],
     ):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
@@ -1525,7 +1532,11 @@ def forward_prepare(
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
-                    positions, hidden_states, forward_batch, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    zero_allocator,
+                    llama_4_scaling,
                 )
             else:
                 # TODO(iforgetmyname): to be separated as a standalone func
@@ -1756,6 +1767,7 @@ def forward_absorb_prepare(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
+        llama_4_scaling: Optional[torch.Tensor],
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
@@ -1933,6 +1945,7 @@ def forward_absorb_prepare(
             zero_allocator,
             positions,
             topk_indices,
+            llama_4_scaling,
         )
 
     def forward_absorb_core(
@@ -1945,6 +1958,7 @@ def forward_absorb_core(
         zero_allocator,
         positions,
         topk_indices,
+        llama_4_scaling,
     ):
         if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
             extra_args = {}
@@ -1952,6 +1966,7 @@ def forward_absorb_core(
                 extra_args = {
                     "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                     "is_neox": self.rotary_emb.is_neox_style,
+                    "llama_4_scaling": llama_4_scaling,
                 }
 
             attn_output = self.attn_mqa(
@@ -1982,6 +1997,10 @@ def forward_absorb_core(
             q = torch.cat([q_nope_out, q_pe], dim=-1)
             k = torch.cat([k_nope, k_pe], dim=-1)
 
+            # Apply llama 4 scaling if provided
+            if llama_4_scaling is not None:
+                q *= llama_4_scaling
+
             attn_output = self.attn_mqa(
                 q,
                 k,
@@ -2972,6 +2991,7 @@ def forward(
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
         gemm_output_zero_allocator: BumpAllocator = None,
+        llama_4_scaling: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         quant_format = (
             "mxfp4"
@@ -3012,6 +3032,7 @@ def forward(
             hidden_states=hidden_states,
             forward_batch=forward_batch,
             zero_allocator=zero_allocator,
+            llama_4_scaling=llama_4_scaling,
         )
 
         hidden_states, residual = self.layer_communicator.prepare_mlp(
@@ -3230,6 +3251,9 @@ def __init__(
         )
         self.layers_to_capture = []
 
+        # llama_4_scaling: for supporting Mistral-Large-3 model
+        self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -3278,6 +3302,18 @@ def forward(
         if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp):
             hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states)
 
+        # llama_4_scaling: for supporting Mistral-Large-3 model
+        # Compute llama 4 scaling once per forward pass if enabled
+        llama_4_scaling: Optional[torch.Tensor] = None
+        if self.llama_4_scaling_config is not None:
+            llama_4_scaling = _get_llama_4_scaling(
+                original_max_position_embeddings=self.llama_4_scaling_config[
+                    "original_max_position_embeddings"
+                ],
+                scaling_beta=self.llama_4_scaling_config["beta"],
+                positions=positions,
+            )
+
         normal_start_layer = self.start_layer
         normal_end_layer = self.end_layer
         if forward_batch.can_run_tbo:
@@ -3301,6 +3337,7 @@ def forward(
                     residual,
                     zero_allocator,
                     gemm_output_zero_allocator,
+                    llama_4_scaling,
                 )
 
             if normal_end_layer != self.end_layer:
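
For reference, the factor this patch threads from DeepseekV2Model.forward down to the attention backends is the one computed by the new _get_llama_4_scaling helper: 1 + beta * log(1 + floor(position / original_max_position_embeddings)), broadcast over the head and head-dim axes and multiplied into the query before attention. Below is a minimal standalone sketch (not part of the patch) of that computation and how it would be applied to a query tensor; the function name llama_4_scaling_sketch, the beta and original_max_position_embeddings values, and the tensor shapes are illustrative assumptions, not taken from any model config.

import torch

def llama_4_scaling_sketch(
    positions: torch.Tensor,  # [num_tokens], token positions as a float tensor
    original_max_position_embeddings: int,
    beta: float,
) -> torch.Tensor:
    # Same formula as _get_llama_4_scaling in the patch:
    # 1 + beta * log(1 + floor(position / original_max_position_embeddings))
    scaling = 1 + beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Add singleton head and head_dim axes so it broadcasts against [num_tokens, num_heads, head_dim]
    return scaling[..., None, None]

# Hypothetical usage with made-up shapes and config values:
positions = torch.tensor([0.0, 100_000.0, 300_000.0])
q = torch.randn(3, 16, 128)  # [num_tokens, num_heads, head_dim]
q = q * llama_4_scaling_sketch(positions, original_max_position_embeddings=131_072, beta=0.1)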