Merged
python/sglang/srt/layers/attention/trtllm_mla_backend.py (11 additions, 0 deletions)
@@ -802,6 +802,7 @@ def forward_decode(
k_rope: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
llama_4_scaling: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MLA kernel."""
merge_query = q_rope is not None
@@ -843,6 +844,11 @@
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
query = query.to(self.q_data_type) * llama_4_scaling
query = query.to(self.data_type)

# Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len is 1
if query.dim() == 3:
query = query.unsqueeze(1)
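Note on the decode-path dtype round trip above: on the FP8 path the query is held in the kernel's input dtype, so the scaling multiply first upcasts to `q_data_type` and then casts back to `data_type` before the kernel call. A minimal sketch of that pattern, with the two dtypes as stand-ins for the backend's attributes (illustrative shapes; the FP8 dtype assumes a recent PyTorch):

```python
import torch

q_data_type = torch.bfloat16     # stand-in for self.q_data_type (math dtype)
data_type = torch.float8_e4m3fn  # stand-in for self.data_type (kernel dtype)

query = torch.randn(4, 16, 576, dtype=q_data_type).to(data_type)
llama_4_scaling = torch.full((4, 1, 1), 1.1, dtype=q_data_type)

# Upcast, scale, and cast back so the kernel still receives its expected
# dtype; FP8 tensors cannot be multiplied directly in eager mode.
query = (query.to(q_data_type) * llama_4_scaling).to(data_type)
print(query.dtype, tuple(query.shape))  # torch.float8_e4m3fn (4, 16, 576)
```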
@@ -903,6 +909,7 @@ def forward_extend(
k_rope: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
is_neox: Optional[bool] = False,
llama_4_scaling: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if (
@@ -955,6 +962,10 @@

q = q.view(-1, layer.tp_q_head_num, layer.head_dim)

# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
q *= llama_4_scaling

if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend(include_v2=True)
python/sglang/srt/models/deepseek_v2.py (44 additions, 7 deletions)
@@ -1158,6 +1158,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0


def _get_llama_4_scaling(
original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
scaling = 1 + scaling_beta * torch.log(
1 + torch.floor(positions / original_max_position_embeddings)
)
# Broadcast over num_heads and head_dim
return scaling[..., None, None]
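For intuition, the helper computes scaling(p) = 1 + beta * log(1 + floor(p / N)) with N = original_max_position_embeddings: positions inside the original context window keep a factor of exactly 1, and the factor grows logarithmically beyond it. A quick check with illustrative values (the real beta and N come from the model config):

```python
import torch

beta, orig_max = 0.1, 8192  # illustrative; read from config in the real path
positions = torch.tensor([0, 4096, 8192, 65536])

scaling = 1 + beta * torch.log(1 + torch.floor(positions / orig_max))
print(scaling)  # tensor([1.0000, 1.0000, 1.0693, 1.2197])
```

The trailing `[..., None, None]` reshapes the factor to [num_tokens, 1, 1] so it broadcasts against a query of shape [num_tokens, num_heads, head_dim].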


class DeepseekV2AttentionMLA(nn.Module):

def __init__(
@@ -1212,12 +1222,6 @@ def __init__(
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"

self.llama_4_scaling = (
config.to_dict()["llama_4_scaling"]["beta"]
if "llama_4_scaling" in config
else None
)

# For tensor parallel attention
if self.q_lora_rank is not None:
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -1470,12 +1474,14 @@ def forward(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor],
):
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
llama_4_scaling=llama_4_scaling,
)
return self.forward_core(s)

@@ -1485,6 +1491,7 @@ def forward_prepare(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor],
):
if self.attn_mha.kv_b_proj is None:
self.attn_mha.kv_b_proj = self.kv_b_proj
@@ -1525,7 +1532,11 @@
elif attn_forward_method == AttnForwardMethod.MLA:
if not self.is_mla_preprocess_enabled:
inner_state = self.forward_absorb_prepare(
positions, hidden_states, forward_batch, zero_allocator
positions,
hidden_states,
forward_batch,
zero_allocator,
llama_4_scaling,
)
else:
# TODO(iforgetmyname): to be separated as a standalone func
@@ -1756,6 +1767,7 @@ def forward_absorb_prepare(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
llama_4_scaling: Optional[torch.Tensor],
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

@@ -1933,6 +1945,7 @@ def forward_absorb_prepare(
zero_allocator,
positions,
topk_indices,
llama_4_scaling,
)

def forward_absorb_core(
@@ -1945,13 +1958,15 @@ def forward_absorb_core(
zero_allocator,
positions,
topk_indices,
llama_4_scaling,
):
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
extra_args = {}
if self._fuse_rope_for_trtllm_mla(forward_batch):
extra_args = {
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
"is_neox": self.rotary_emb.is_neox_style,
"llama_4_scaling": llama_4_scaling,
}
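Note: when rope is fused into the TRTLLM MLA kernel, the query is quantized and rotated inside `attn_mqa`, so the factor must travel with the call for the backend to apply after the rope step (the trtllm_mla_backend.py change above); on the unfused path it is applied to `q` directly, as the next hunk shows. A runnable toy of that dispatch, with all names standing in for the diff's attributes:

```python
import torch

def scale_query(q, llama_4_scaling, fuse_rope):
    """Toy mirror of forward_absorb_core's dispatch: the fused path defers
    the multiply to the attention backend via kwargs, the unfused path
    scales q before the kernel call."""
    extra_args = {}
    if fuse_rope:
        extra_args["llama_4_scaling"] = llama_4_scaling
    elif llama_4_scaling is not None:
        q = q * llama_4_scaling
    return q, extra_args

q = torch.randn(4, 16, 576)
scaling = torch.full((4, 1, 1), 1.1)
_, kwargs = scale_query(q, scaling, fuse_rope=True)
q_scaled, _ = scale_query(q, scaling, fuse_rope=False)
print(list(kwargs), tuple(q_scaled.shape))  # ['llama_4_scaling'] (4, 16, 576)
```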

attn_output = self.attn_mqa(
@@ -1982,6 +1997,10 @@ def forward_absorb_core(
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)

# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
q *= llama_4_scaling

attn_output = self.attn_mqa(
q,
k,
@@ -2972,6 +2991,7 @@ def forward(
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
llama_4_scaling: Optional[torch.Tensor] = None,
) -> torch.Tensor:
quant_format = (
"mxfp4"
@@ -3012,6 +3032,7 @@
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
llama_4_scaling=llama_4_scaling,
)

hidden_states, residual = self.layer_communicator.prepare_mlp(
@@ -3230,6 +3251,9 @@ def __init__(
)
self.layers_to_capture = []

# llama_4_scaling: for supporting the Mistral-Large-3 model
self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
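The `getattr` above implies a config entry like the one sketched here; the two keys match exactly what the forward pass reads (`beta` and `original_max_position_embeddings`). Values are illustrative only:

```python
# Illustrative "llama_4_scaling" config entry (values made up); checkpoints
# that omit the entry leave the feature disabled via the getattr default.
llama_4_scaling = {
    "beta": 0.1,
    "original_max_position_embeddings": 8192,
}
```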

def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens

@@ -3278,6 +3302,18 @@ def forward(
if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp):
hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states)

# llama_4_scaling: for supporting the Mistral-Large-3 model
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling: Optional[torch.Tensor] = None
if self.llama_4_scaling_config is not None:
llama_4_scaling = _get_llama_4_scaling(
original_max_position_embeddings=self.llama_4_scaling_config[
"original_max_position_embeddings"
],
scaling_beta=self.llama_4_scaling_config["beta"],
positions=positions,
)
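Computing the factor once here and threading it through every decoder layer avoids redoing the log per layer; each layer then only pays for a broadcast multiply. A shape check under assumed sizes:

```python
import torch

num_tokens, num_heads, head_dim = 8, 16, 576  # illustrative sizes
positions = torch.arange(num_tokens)
beta, orig_max = 0.1, 4

scaling = (1 + beta * torch.log(1 + torch.floor(positions / orig_max)))[..., None, None]
q = torch.randn(num_tokens, num_heads, head_dim)
q *= scaling  # [8, 1, 1] broadcasts over heads and head_dim
print(tuple(scaling.shape), tuple(q.shape))  # (8, 1, 1) (8, 16, 576)
```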

normal_start_layer = self.start_layer
normal_end_layer = self.end_layer
if forward_batch.can_run_tbo:
@@ -3301,6 +3337,7 @@
residual,
zero_allocator,
gemm_output_zero_allocator,
llama_4_scaling,
)

if normal_end_layer != self.end_layer: