From 90098a5388c7ee783667b147161ac5e1b95f2398 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 17 Nov 2025 09:19:44 +0000 Subject: [PATCH 01/30] Add Mistral Large 3 support. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../kernels/fused_moe_triton/common_utils.py | 3 +- python/sglang/srt/configs/model_config.py | 6 + .../srt/layers/moe/fused_moe_triton/layer.py | 12 +- python/sglang/srt/layers/quantization/fp8.py | 182 ++++-- .../srt/layers/quantization/modelopt_quant.py | 30 +- python/sglang/srt/models/deepseek_v2.py | 46 +- python/sglang/srt/models/mistral_large_3.py | 76 +++ .../srt/models/mistral_large_3_eagle.py | 104 ++++ python/sglang/srt/models/pixtral.py | 568 +++++++++++++++++- .../srt/multimodal/processors/pixtral.py | 12 +- python/sglang/srt/server_args.py | 14 +- .../sglang/srt/utils/hf_transformers_utils.py | 52 +- python/sglang/srt/utils/mistral_utils.py | 303 ++++++++++ sgl-router/src/routers/router_manager.rs | 30 +- 14 files changed, 1347 insertions(+), 91 deletions(-) create mode 100644 python/sglang/srt/models/mistral_large_3.py create mode 100644 python/sglang/srt/models/mistral_large_3_eagle.py create mode 100644 python/sglang/srt/utils/mistral_utils.py diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index 40f8697f1e4e..dd8b09f8a79b 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -73,11 +73,12 @@ def get_model_config( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", + "MistralLarge3ForCausalLM", ]: E = (config.n_routed_experts // ep_size) + ( 0 if disable_shared_experts_fusion - or architecture not in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"] + or architecture not in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", "MistralLarge3ForCausalLM"] else 1 ) topk = config.num_experts_per_tok + ( diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 7f1f3e472c6f..bc88a39a274d 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -58,6 +58,8 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool: "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", "DeepseekV3ForCausalLMNextN", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ] and getattr(config, "index_topk", None) is not None ) @@ -334,6 +336,9 @@ def _derive_model_shapes(self): or "LongcatFlashForCausalLM" in self.hf_config.architectures or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures or "DotsVLMForCausalLM" in self.hf_config.architectures + or "MistralLarge3ForCausalLM" in self.hf_config.architectures + or "PixtralForConditionalGeneration" in self.hf_config.architectures + or "MistralLarge3ForCausalLMEagle" in self.hf_config.architectures ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA @@ -939,6 +944,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "MultiModalityCausalLM", "MllamaForConditionalGeneration", "NemotronH_Nano_VL_V2", + "PixtralForConditionalGeneration", "Qwen2AudioForConditionalGeneration", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration", diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 6e0487a1ddc8..5f8b27f2b385 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1086,9 +1086,7 @@ def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor): hs_fp4_bytes, hs_sf_bytes = fp4_quantize( hidden_states, self.w13_input_scale_quant, - 16, # sf_vec_size - False, # use_ue8m0 - False, # is_sf_swizzled_layout + is_sf_swizzled_layout=False, ) hs_fp4 = hs_fp4_bytes.reshape( @@ -1129,6 +1127,12 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): else topk_config.correction_bias.to(hidden_states.dtype) ) + router_logits = ( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ) + with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): @@ -1170,7 +1174,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): intermediate_size=self.intermediate_size_per_partition, local_expert_offset=self.moe_ep_rank * self.num_local_experts, local_num_experts=self.num_local_experts, - routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, + routed_scaling_factor=None, tile_tokens_dim=None, routing_method_type=routing_method_type, do_finalize=True, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 878f70619728..5b6c6f3788f1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -563,6 +563,7 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn + tp_size = get_tensor_model_parallel_world_size() if self.block_quant: block_n, block_k = ( @@ -779,7 +780,6 @@ def process_weights_after_loading(self, layer: Module) -> None: _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) else: # For fp8 moe run with deepgemm, the expert weights and scales need be requantized to ue8m0 - from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE from sglang.srt.model_loader.utils import ( should_deepgemm_weight_requant_ue8m0, @@ -913,10 +913,83 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False ) - if _is_hip: self.process_weights_hip_scale_padding(layer) + + # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled + if get_moe_runner_backend().is_flashinfer_trtllm(): + from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a + + # Note: No need to swap W13 halves, they are already in the correct order: [Gate, Up] + num_experts, two_n, hidden = layer.w13_weight.shape + + # 2) Reorder rows for fused gated activation (W13) + w13_interleaved = [ + reorder_rows_for_gated_act_gemm(layer.w13_weight[i]) + for i in range(num_experts) + ] + w13_interleaved = torch.stack(w13_interleaved).reshape( + num_experts, two_n, hidden + ) + + # 3) Shuffle weights for transposed MMA output (both W13, W2) + epilogue_tile_m = 128 + w13_shuffled = [ + shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) + for i in range(num_experts) + ] + w2_shuffled = [ + shuffle_matrix_a(layer.w2_weight[i].view(torch.uint8), epilogue_tile_m) + for i in range(num_experts) + ] + + layer.w13_weight = Parameter( + torch.stack(w13_shuffled).view(torch.float8_e4m3fn), + requires_grad=False, + ) + layer.w2_weight = Parameter( + torch.stack(w2_shuffled).view(torch.float8_e4m3fn), + requires_grad=False, + ) + + # Precompute and register per-expert output scaling factors for FI MoE + if get_moe_runner_backend().is_flashinfer_trtllm(): + # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction + assert ( + hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None + ) + assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None + assert ( + hasattr(layer, "w13_weight_scale") + and layer.w13_weight_scale is not None + ) + assert ( + hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None + ) + + input_scale = layer.w13_input_scale.to(torch.float32) + activation_scale = layer.w2_input_scale.to(torch.float32) + w13_weight_scale = layer.w13_weight_scale.to(torch.float32) + w2_weight_scale = layer.w2_weight_scale.to(torch.float32) + + output1_scales_scalar = ( + w13_weight_scale * input_scale * (1.0 / activation_scale) + ) + output1_scales_gate_scalar = w13_weight_scale * input_scale + output2_scales_scalar = activation_scale * w2_weight_scale + + layer.output1_scales_scalar = Parameter( + output1_scales_scalar, requires_grad=False + ) + layer.output1_scales_gate_scalar = Parameter( + output1_scales_gate_scalar, requires_grad=False + ) + layer.output2_scales_scalar = Parameter( + output2_scales_scalar, requires_grad=False + ) return + + return def process_weights_hip_int4(self, layer: Module): # TODO: _use_aiter: add after triton kernel added @@ -1001,7 +1074,6 @@ def create_moe_runner( from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.utils import ( get_moe_a2a_backend, - get_moe_runner_backend, ) self.moe_runner_config = moe_runner_config @@ -1222,7 +1294,7 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer.fused_moe import trtllm_fp8_block_scale_moe + from flashinfer.fused_moe import trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe from sglang.srt.layers.moe.topk import TopKOutputChecker from sglang.srt.layers.moe.utils import RoutingMethodType @@ -1233,9 +1305,13 @@ def apply_with_router_logits( assert ( activation == "silu" ), "Only silu is supported for flashinfer blockscale fp8 moe" - a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) - # NOTE: scales of hidden states have to be transposed! - a_sf_t = a_sf.t().contiguous() + + if self.block_quant: + a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() + else: + a_q, _ = scaled_fp8_quant(x, layer.w13_input_scale) correction_bias = ( None @@ -1243,43 +1319,71 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - routing_method_type = getattr(layer, "routing_method_type") + routing_method_type = getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): - # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. - # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 - # so we put the whole function under the ``use_symmetric_memory`` context manager. - # If the bug is fixed, we can only put the output tensor allocation under the context manager. - return trtllm_fp8_block_scale_moe( - routing_logits=( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ), - tile_tokens_dim=None, - routing_method_type=routing_method_type, - use_shuffled_weight=False, - ) + if self.block_quant: + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. + return trtllm_fp8_block_scale_moe( + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + tile_tokens_dim=None, + routing_method_type=routing_method_type, + use_shuffled_weight=False, + ) + else: + routing_bias_cast = ( + None if correction_bias is None else correction_bias.to(torch.bfloat16) + ) + + return trtllm_fp8_per_tensor_scale_moe( + routing_logits=router_logits.to(torch.bfloat16), + routing_bias=routing_bias_cast, + hidden_states=a_q, + gemm1_weights=layer.w13_weight, + output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + gemm2_weights=layer.w2_weight, + output2_scales_scalar=layer.output2_scales_scalar, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + use_routing_scales_on_input=False, + routing_method_type=routing_method_type, + ) def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 8d44584e7f26..ef91692166a1 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1459,10 +1459,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) + def _slice_scale(w): + assert w.shape == (layer.num_experts,) + assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts + return w[ + layer.moe_ep_rank + * layer.num_local_experts : (layer.moe_ep_rank + 1) + * layer.num_local_experts + ] + # Calculate input scales based on strategy if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe: - w13_input_scale = layer.w13_input_scale.max().to(torch.float32) - w2_input_scale = layer.w2_input_scale.max().to(torch.float32) + w13_input_scale = ( + layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) + ) + w2_input_scale = ( + layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) + ) + + if layer.moe_ep_size > 1: + w13_input_scale = _slice_scale(w13_input_scale) + w2_input_scale = _slice_scale(w2_input_scale) elif self.enable_flashinfer_cutedsl_moe: # All-expert-one-input-scale is mathematically different from default per-expert-input-scale # Thus we allow users to switch the flag to do thorough testing @@ -1479,15 +1496,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_input_scale = layer.w2_input_scale - def _slice_scale(w): - assert w.shape == (layer.num_experts,) - assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts - return w[ - layer.moe_ep_rank - * layer.num_local_experts : (layer.moe_ep_rank + 1) - * layer.num_local_experts - ] - w13_input_scale = _slice_scale(w13_input_scale) w2_input_scale = _slice_scale(w2_input_scale) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 163d128457ba..6d84f9c38f6d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -649,7 +649,11 @@ def __init__( layer_id=self.layer_id, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, - routing_method_type=RoutingMethodType.DeepSeekV3, + routing_method_type=( + RoutingMethodType.Renormalize + if config.norm_topk_prob + else RoutingMethodType.DeepSeekV3 + ), prefix=add_prefix("experts", prefix), ) @@ -1204,6 +1208,12 @@ def __init__( if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" + self.llama_4_scaling = ( + config.to_dict()["llama_4_scaling"]["beta"] + if "llama_4_scaling" in config + else None + ) + # For tensor parallel attention if self.q_lora_rank is not None: self.fused_qkv_a_proj_with_mqa = ReplicatedLinear( @@ -1968,6 +1978,9 @@ def forward_absorb_core( q = torch.cat([q_nope_out, q_pe], dim=-1) k = torch.cat([k_nope, k_pe], dim=-1) + if self.llama_4_scaling: + q *= self.llama_4_scaling + attn_output = self.attn_mqa( q, k, @@ -3333,6 +3346,7 @@ def forward( class DeepseekV2ForCausalLM(nn.Module): # for quark model load packed_modules_mapping = {} + model_cls = DeepseekV2Model def __init__( self, @@ -3359,7 +3373,7 @@ def __init__( self.quant_config = quant_config self.determine_num_fused_shared_experts() self.use_nsa = is_deepseek_nsa(config) - self.model = DeepseekV2Model( + self.model = self.model_cls( config, quant_config, prefix=add_prefix("model", prefix) ) if self.pp_group.is_last_rank: @@ -3930,16 +3944,24 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal ): q_a_proj_weight = cached_a_proj[q_a_proj_name] kv_a_proj_weight = cached_a_proj[kv_a_proj_name] - cat_dim = 0 - if self.quant_config is not None and ( - self.quant_config.get_name() == "awq" - or self.quant_config.get_name() == "awq_marlin" - or self.quant_config.get_name() == "moe_wna16" - ): - cat_dim = 1 - fused_weight = torch.cat( - [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim - ) + + if q_a_proj_weight.shape == torch.Size( + [] + ) and kv_a_proj_weight.shape == torch.Size([]): + fused_weight = q_a_proj_weight + else: + cat_dim = 0 + if self.quant_config is not None and ( + self.quant_config.get_name() == "awq" + or self.quant_config.get_name() == "awq_marlin" + or self.quant_config.get_name() == "moe_wna16" + ): + cat_dim = 1 + + fused_weight = torch.cat( + [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim + ) + param_name = ( name.replace( "q_a_proj", "fused_qkv_a_proj_with_mqa" diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py new file mode 100644 index 000000000000..586e6a926209 --- /dev/null +++ b/python/sglang/srt/models/mistral_large_3.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable + +import regex as re +import torch + +from sglang.srt.models.deepseek_v2 import DeepseekV3ForCausalLM + + +class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM): + # fmt: off + remapping = { + r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq\.(\w+)": r"model.layers.\1.self_attn.q_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501 + # FP8 scales + r"layers\.(\d+)\.attention\.k_fake_quantizer\.qscale_act": r"model.layers.\1.self_attn.mla_attn.mla_attn.k_scale", # noqa: E501 + r"layers\.(\d+)\.attention\.q_fake_quantizer\.qscale_act": r"model.layers.\1.self_attn.mla_attn.mla_attn.q_scale", # noqa: E501 + r"layers\.(\d+)\.attention\.v_fake_quantizer\.qscale_act": r"model.layers.\1.self_attn.mla_attn.mla_attn.v_scale", # noqa: E501 + r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501 + r"layers\.(\d+)\.router_biases": r"model.layers.\1.mlp.gate.e_score_correction_bias", # noqa: E501 + r"norm\.weight": "model.norm.weight", # noqa: E501 + r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501 + r"output\.weight": "lm_head.weight", # noqa: E501 + } + # fmt: on + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return super().load_weights(self._iterable_remap_mistral_to_ds(weights)) + + def _iterable_remap_mistral_to_ds( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + """Remap Mistral parameters to DeepseekV2 parameters.""" + for name, loaded_weight in weights: + for k, v in self.remapping.items(): + match = re.fullmatch(k, name) + if match: + name = re.sub(k, v, name) + break + else: + + continue + + # Note(Andy): Unlike Llama, this implementation uses + # is_neox_style=False for RoPE, which matches Mistral's implementation. + # Thus we don't need to permute the q/k weights (unlike Llama) + + # Remapping scale names. We could do this in the regex above but it + # would triple the number of lines for most layers. + if name.endswith(".qscale_act"): + name = re.sub(r"\.qscale_act$", ".input_scale", name) + elif name.endswith(".qscale_weight"): + name = re.sub(r"\.qscale_weight$", ".weight_scale", name) + + yield name, loaded_weight + + +EntryClass = MistralLarge3ForCausalLM diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py new file mode 100644 index 000000000000..e85a629e76f0 --- /dev/null +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -0,0 +1,104 @@ +from typing import Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from python.sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import RowParallelLinear +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV2Model +from sglang.srt.models.mistral_large_3 import MistralLarge3ForCausalLM +from sglang.srt.utils import add_prefix + + +class MistralLarge3Model(DeepseekV2Model): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + nn.Module.__init__(self) + + self.config = config + self.vocab_size = config.vocab_size + assert get_pp_group().world_size == 1 + self.pp_group = get_pp_group() + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + config=config, + prefix=add_prefix(prefix, f"layers.{i}"), + quant_config=quant_config, + layer_id=i, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.start_layer = 0 + self.end_layer = self.config.num_hidden_layers + + self.fc = RowParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix(prefix, "fc"), + input_is_parallel=False, + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers_to_capture = [] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + input_embeds, _ = self.fc( + torch.cat((input_embeds, forward_batch.spec_info.hidden_states), dim=-1) + ) + output = super().forward( + input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors + ) + assert isinstance(output, torch.Tensor) + return output + + +class MistralLarge3ForCausalLMEagle(MistralLarge3ForCausalLM): + remapping = MistralLarge3ForCausalLM.remapping | { + r"eagle_linear\.weight": r"model.fc.weight", + r"eagle_linear\.qscale_act": r"model.fc.input_scale", + r"eagle_linear\.qscale_weight": r"model.fc.weight_scale", + } + + def __init__( + self, + *, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + config.quant_config = quant_config + self.model_cls = MistralLarge3Model + super().__init__(config=config, quant_config=quant_config, prefix=prefix) + + +EntryClass = [MistralLarge3ForCausalLMEagle] diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 249a5ce81bba..1f9c236bb2ef 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -16,10 +16,12 @@ Using mistral-community/pixtral-12b as reference. """ +from dataclasses import dataclass, fields from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn +import torch.nn.functional as F from transformers import PixtralVisionConfig, PretrainedConfig from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding from transformers.models.pixtral.modeling_pixtral import ( @@ -32,9 +34,398 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens -from sglang.srt.managers.schedule_batch import MultimodalInputs +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.mistral_large_3 import MistralLarge3ForCausalLM + +USE_XFORMERS_OPS = False +PATCH_MERGE = "patch_merge" + + +# Vision encoder +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float # for rope-2D + image_token_id: int + adapter_bias: bool = True + spatial_merge_size: int = 1 + add_pre_mm_projector_layer_norm: bool = False + mm_projector_id: str = "" + + +class PixtralForConditionalGeneration(nn.Module): + merge_by_field_config = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, config, prefix: str = "", **kwargs): + super().__init__() + self.config = config + dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + vision_args = { + key: value + for key, value in self.config.vision_config.to_dict().items() + if key in dataclass_fields + } + + self.vision_args = VisionEncoderArgs(**vision_args) + + self.language_model = MistralLarge3ForCausalLM( + config=self.config.text_config, + quant_config=kwargs.get("quant_config"), + ) + + self.vision_encoder = VisionTransformer(self.vision_args) + + if self.vision_args.add_pre_mm_projector_layer_norm: + self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5) + + if self.vision_args.mm_projector_id == PATCH_MERGE: + self.patch_merger = PatchMerger( + vision_encoder_dim=self.vision_args.hidden_size, + spatial_merge_size=self.vision_args.spatial_merge_size, + use_mlp_bias=False, + ) + + self.vision_language_adapter = VisionLanguageAdapter( + self.vision_args, dim=self.config.text_config.hidden_size + ) + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("vision_encoder") + + def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("vision_language_adapter") + + def is_patch_merger(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("patch_merger") + + def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("pre_mm_projector_norm") + + # Get references to parameters for direct loading + vision_encoder_dict = dict(self.vision_encoder.named_parameters()) + patch_merger_dict = ( + dict(self.patch_merger.named_parameters()) + if self.vision_args.mm_projector_id == PATCH_MERGE + else dict() + ) + pre_mm_projector_norm_dict = ( + dict(self.pre_mm_projector_norm.named_parameters()) + if self.vision_args.add_pre_mm_projector_layer_norm + else dict() + ) + vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters()) + + def llm_weights_generator(): + # Single pass over weights + for name, w in weights: + if is_vision_encoder_weights((name, w)): + # Load vision encoder weights directly + trimmed_name = ".".join(name.split(".")[1:]) + # NOTE: The current nvfp4 model has extra weights that we need to ignore, called + # vision_encoder.transformer.layers.*.attention.{k,v}_fake_quantizer.qscale_act + # TODO: Remove this if condition once the model is fixed + if "fake_quantizer.qscale_act" in trimmed_name: + continue + param = vision_encoder_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_patch_merger((name, w)): + # Load vision patch merger weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = patch_merger_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_pre_mm_projector_norm((name, w)): + # Load vision pre_mm_projector_norm weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = pre_mm_projector_norm_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_vision_lang_adapter_weights((name, w)): + # Load vision-language adapter weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = vision_lang_adapter_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + else: + # LLM weights: yield them to be loaded + # by language_model.load_weights + yield (name, w) + + # Now we call the language model load with the generator + self.language_model.load_weights(llm_weights_generator()) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + images = [item.feature for item in items] + # Process images through vision encoder + image_features = self.vision_encoder(images) + if self.vision_args.add_pre_mm_projector_layer_norm: + image_features = image_features.view(-1, image_features.shape[-1]) + image_features = self.pre_mm_projector_norm(image_features) + if self.vision_args.mm_projector_id == PATCH_MERGE: + patch_size = self.vision_args.patch_size + img_patch_dims = [ + (img.shape[-2] // patch_size, img.shape[-1] // patch_size) + for img in images + for _ in range(img.shape[0]) + ] + image_features = self.patch_merger( + image_features, image_sizes=img_patch_dims + ) + image_embeds = self.vision_language_adapter(image_features) + return image_embeds + + def forward(self, input_ids, positions, forward_batch): + return general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.language_model, + multimodal_model=self, + positions=positions, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def get_embed_and_head(self): + return self.language_model.get_embed_and_head() + + +class PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__( + self, + vision_encoder_dim: int, + spatial_merge_size: int, + use_mlp_bias: bool = False, + ) -> None: + super().__init__() + + mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) + + self.spatial_merge_size = spatial_merge_size + self.mlp_input_dim = mlp_input_dim + + self.merging_layer = nn.Linear( + mlp_input_dim, + vision_encoder_dim, + bias=use_mlp_bias, + ) + + def forward( + self, x: torch.Tensor, image_sizes: list[tuple[int, int]] + ) -> torch.Tensor: + # image_sizes specified in tokens + assert sum([h * w for h, w in image_sizes]) == x.shape[-2] + + # x is (N, vision_encoder_dim) + x = self.permute(x, image_sizes) + + # x is (N / spatial_merge_size ** 2, + # vision_encoder_dim * spatial_merge_size ** 2) + x = self.merging_layer(x) + + # x is (N / spatial_merge_size ** 2, vision_encoder_dim) + return x + + def permute( + self, + x: torch.Tensor, + image_sizes: list[tuple[int, int]], + ) -> torch.Tensor: + """ + Args: + x: (N, D) where N is flattened and concatenated patch tokens + for all images + image_sizes: list of tuple of (height, width) in tokens for + each image + Returns: + image_features: reorders patch tokens so each grid of + (spatial_merge_size, spatial_merge_size) is contiguous. + now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) + """ + + sub_grids = get_sub_grids( + x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size + ) # list of [d x sub_grid_size x sub_grid_size x n_patches] + permuted_tensor: list[torch.Tensor] = [] + for grid in sub_grids: + n_patches = grid.shape[-1] + permuted_tensor.append( + grid.view(-1, n_patches).t() + ) # n_patches x d * sub_grid_size * sub_grid_size + return torch.cat( + permuted_tensor, dim=0 + ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) + + +def get_sub_grids( + x: torch.Tensor, + image_sizes: list[tuple[int, int]], + spatial_merge_size: int, +) -> list[torch.Tensor]: + # image_sizes specified in tokens + tokens_per_image = [h * w for h, w in image_sizes] + d = x.shape[-1] + all_img_sub_grids: list[torch.Tensor] = [] + sub_grid_size = spatial_merge_size + + for image_index, image_tokens in enumerate(x.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[ + None, :, :, : + ] # 1 x d x h x w + sub_grids = torch.nn.functional.unfold( + image_grid, kernel_size=sub_grid_size, stride=sub_grid_size + ) + sub_grids = sub_grids.view( + 1, d, sub_grid_size, sub_grid_size, -1 + ) # 1 x d x sub_grid_size x sub_grid_size x n_patches + + all_img_sub_grids.append(sub_grids[0]) + + return all_img_sub_grids + + +class VisionTransformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = Transformer(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: torch.Tensor | None = None + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.types.Device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: list[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + images: list of N_img images of variable sizes, + each of shape (B, C, H, W) + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for img in images] + + patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] + + patch_embeds = torch.cat(patch_embeds, dim=1) + patch_embeds_shape = patch_embeds.shape + patch_embeds = patch_embeds.view(-1, patch_embeds_shape[-1]) + patch_embeds = self.ln_pre(patch_embeds) + patch_embeds = patch_embeds.view(patch_embeds_shape) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + if USE_XFORMERS_OPS: + from xformers import ops as xops + + mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) + else: + from transformers.models.pixtral.modeling_pixtral import ( + generate_block_attention_mask, + ) + + mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + return self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + +def position_meshgrid( + patch_embeds_list: list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) + return positions class PixtralHFMLP(nn.Module): @@ -81,6 +472,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out +class VisionLanguageAdapter(nn.Module): + def __init__(self, args: VisionEncoderArgs, dim: int): + super().__init__() + assert isinstance(args, VisionEncoderArgs) + self.w_in = nn.Linear( + args.hidden_size, + dim, + bias=args.adapter_bias, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w_out(self.gelu(self.w_in(x))) + + class PixtralHFTransformerBlock(nn.Module): """Transformer block for PixtralHFVisionModel using SGLang components.""" @@ -156,6 +563,161 @@ def forward( return output +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert ndim > 1 + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) + to be indexed by (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def apply_rotary_emb_vit( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + assert freqs_cis.dtype == torch.complex64 + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class FeedForward(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + assert args.intermediate_size is not None + self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) + self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Attention(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + assert not args.hidden_size % args.num_attention_heads + self.n_heads = args.num_attention_heads + self.head_dim = args.hidden_size // args.num_attention_heads + + self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + batch, patches, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.reshape(batch, patches, self.n_heads, self.head_dim) + k = k.reshape(batch, patches, self.n_heads, self.head_dim) + v = v.reshape(batch, patches, self.n_heads, self.head_dim) + + q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) + + if USE_XFORMERS_OPS: + from xformers import ops as xops + + out = xops.memory_efficient_attention(q, k, v, attn_bias=mask) + else: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + out = out.transpose(1, 2) + + out = out.reshape(batch, patches, self.n_heads * self.head_dim) + return self.wo(out) + + +class TransformerBlock(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.attention = Attention(args) + self.feed_forward = FeedForward(args) + self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) + self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + attention_norm_x = self.attention_norm(x.view(-1, x.shape[-1])) + attention_norm_x = attention_norm_x.view(x.shape) + r = self.attention.forward(attention_norm_x, mask=mask, freqs_cis=freqs_cis) + h = x + r + ffn_norm_h = self.ffn_norm(h.view(-1, h.shape[-1])) + ffn_norm_h = ffn_norm_h.view(h.shape) + r = self.feed_forward.forward(ffn_norm_h) + out = h + r + return out + + +class Transformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor | None, + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + class PixtralHFTransformer(nn.Module): """Transformer for PixtralHFVisionModel using SGLang components.""" @@ -456,4 +1018,4 @@ class PixtralVisionModel(PixtralHFVisionModel): # Register the model classes for external access -EntryClass = [PixtralVisionModel] +EntryClass = [PixtralForConditionalGeneration, PixtralVisionModel] diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index af5cedec9fa6..6b6ab34ad1a0 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -6,7 +6,10 @@ _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) -from sglang.srt.models.pixtral import PixtralVisionModel +from sglang.srt.models.pixtral import ( + PixtralForConditionalGeneration, + PixtralVisionModel, +) from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, MultimodalSpecialTokens, @@ -14,7 +17,7 @@ class PixtralProcessor(BaseMultimodalProcessor): - models = [PixtralVisionModel] + models = [PixtralVisionModel, PixtralForConditionalGeneration] PAD_TOKEN = "" IMG_BREAK_TOKEN_ID = 12 @@ -30,7 +33,6 @@ def get_patch_grid_size( patch_width = patch_height = self.patch_size ratio = max(image_width / max_width, image_height / max_height) - if ratio > 1: image_width = int(math.floor(image_width / ratio)) image_height = int(math.floor(image_height / ratio)) @@ -52,6 +54,10 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.vision_config = hf_config.vision_config self.image_size = self.vision_config.image_size self.patch_size = self.vision_config.patch_size + + self._processor.patch_size = self.patch_size + self._processor.spatial_merge_size = self.vision_config.spatial_merge_size + self.mm_tokens = MultimodalSpecialTokens( image_token=_processor.image_token, image_token_id=self.IM_TOKEN_ID, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 18c073a897cd..e465a1543b61 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -936,7 +936,11 @@ def _handle_model_specific_adjustments(self): hf_config = self.get_hf_config() model_arch = hf_config.architectures[0] - if model_arch in ["DeepseekV3ForCausalLM"]: + + if model_arch in ["MistralLarge3ForCausalLM", "PixtralForConditionalGeneration"]: + self.dtype = "bfloat16" + + if model_arch in ["DeepseekV3ForCausalLM", "MistralLarge3ForCausalLM", "PixtralForConditionalGeneration"]: if is_deepseek_nsa(hf_config): if ( self.attention_backend is None @@ -1042,7 +1046,7 @@ def _handle_model_specific_adjustments(self): # Default DeepSeek V3/R1 native FP8 when not explicitly set, # Because we need this condition for an assertion in # flashinfer_trtllm MoE runner backend. - if quant_method is None: + if quant_method is None and model_arch == "DeepseekV3ForCausalLM": self.quantization = "fp8" logger.info( "Quantization not specified, default to fp8 for DeepSeek on sm100" @@ -1691,6 +1695,8 @@ def _handle_speculative_decoding(self): "Glm4MoeForCausalLM", "BailingMoeForCausalLM", "BailingMoeV2ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ]: if self.speculative_draft_model_path is None: self.speculative_draft_model_path = self.model_path @@ -1931,6 +1937,8 @@ def _handle_deterministic_inference(self): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ] except Exception: pass @@ -4524,6 +4532,8 @@ def auto_choose_speculative_params(self: ServerArgs): "GptOssForCausalLM", "BailingMoeForCausalLM", "BailingMoeV2ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ]: # The default value for deepseek and gpt-oss return (3, 1, 4) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 0e71dfb31383..00d2fc173c6f 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -60,7 +60,7 @@ from sglang.srt.configs.internvl import InternVLChatConfig from sglang.srt.connector import create_remote_connector from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR -from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset +from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ ChatGLMConfig, @@ -184,6 +184,41 @@ def _load_deepseek_v32_model( ) +# Temporary hack for Mistral Large +def _load_mistral_large_3_for_causal_LM( + model_path: str, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +): + # first get the local path + local_path = download_from_hf(model_path) + # then load the config file in json + parser = mistral_utils.MistralConfigParser() + config_dict, _ = parser.parse(local_path) + + tmp_path = os.path.join(tempfile.gettempdir(), "_tmp_config_folder") + os.makedirs(tmp_path, exist_ok=True) + + unique_path = os.path.join(tmp_path, f"mistral_large_3_{os.getpid()}") + with open(unique_path, "w") as f: + json.dump(config_dict, f) + + loaded_config = AutoConfig.from_pretrained( + unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + text_config = getattr(loaded_config, "text_config", None) + if text_config is not None and isinstance(text_config, dict): + text_config = AutoConfig.for_model(**text_config) + setattr(loaded_config, "text_config", text_config) + vision_config = getattr(loaded_config, "vision_config", None) + if vision_config is not None and isinstance(vision_config, dict): + vision_config = AutoConfig.for_model(**vision_config) + setattr(loaded_config, "vision_config", vision_config) + + return loaded_config + + def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool: # TODO: Remove this workaround related when AutoConfig correctly identifies deepseek-ocr. # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2. @@ -215,17 +250,18 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - try: - config = AutoConfig.from_pretrained( + if "mistral-large-3" in model: + config = _load_mistral_large_3_for_causal_LM( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) - - except ValueError as e: - if not "deepseek_v32" in str(e): - raise e + elif "deepseek_v32" in model: config = _load_deepseek_v32_model( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) + else: + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) if ( config.architectures is not None @@ -465,14 +501,12 @@ def get_processor( ): # pop 'revision' from kwargs if present. revision = kwargs.pop("revision", tokenizer_revision) - config = AutoConfig.from_pretrained( tokenizer_name, trust_remote_code=trust_remote_code, revision=revision, **kwargs, ) - if _is_deepseek_ocr_model(config): # Temporary hack for load deepseek-ocr config.model_type = "deepseek-ocr" diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py new file mode 100644 index 000000000000..d4ce44aded5a --- /dev/null +++ b/python/sglang/srt/utils/mistral_utils.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from pathlib import Path +from typing import Any + +from transformers import PretrainedConfig, WhisperConfig + +from sglang.srt.utils import logger + + +def adapt_config_dict( + config_dict: dict[str, Any], model: str, **kwargs +) -> PretrainedConfig: + config_dict.update(kwargs) + config_dict = _remap_general_mistral_args(config_dict) + + if bool(config_dict.get("quantization")): + config_dict = _remap_mistral_quantization_args(config_dict) + + is_moe = bool(config_dict.get("moe")) + is_mistral_large_3 = ( + is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 + ) + if is_moe: + if is_mistral_large_3: + config_dict = _remap_moe_args(config_dict) + config_dict["model_type"] = "deepseek_v3" + config_dict["architectures"] = ["MistralLarge3ForCausalLM"] + + assert ( + "llama_4_scaling" in config_dict + ), "MistralLarge3 expect llama4 scaling config." + llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [ + key in config_dict["llama_4_scaling"] + for key in llama_4_scaling_config_keys + ] + ), ( + "llama_4_scaling config should define the keys: " + f"{','.join(llama_4_scaling_config_keys)}" + ) + else: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if bool(config_dict.get("yarn")): + config_dict = _remap_mistral_yarn_args(config_dict) + + if bool(config_dict.get("llama_4_scaling")): + llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [ + key in config_dict["llama_4_scaling"] + for key in llama_4_scaling_config_keys + ] + ), ( + "llama_4_scaling config should define the keys: " + f"{','.join(llama_4_scaling_config_keys)}" + ) + + is_vision = bool( + (config_dict.get("multimodal") or {}).get("vision_encoder_args") + or config_dict.get("vision_encoder") + ) + is_audio = bool( + ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get( + "encoder_args" + ) + ) + is_eagle = "eagle" in model + + assert not (is_vision and is_audio), "Vision and audio are mutually exclusive" + + if is_vision: + config_dict = _remap_mistral_vision_args(config_dict) + if is_audio: + config_dict = _remap_mistral_audio_args(config_dict) + if is_eagle: + config_dict = _remap_mistral_eagle_args(config_dict) + + config = PretrainedConfig.from_dict(config_dict) + + logger.debug("Initialized config %s", config) + + return config_dict, config + + +def _remap_mistral_vision_args(config: dict) -> dict: + if config.get("multimodal"): + vision_config = config.pop("multimodal") + else: + vision_config = config.pop("vision_encoder") + + quant_config = config.get("quantization_config") + + config = { + "model_type": "pixtral", + "architectures": ["PixtralForConditionalGeneration"], + "text_config": config, + "vision_config": {"model_type": "pixtral", **vision_config}, + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_mistral_yarn_args(config: dict) -> dict: + yarn_config_map = { + "factor": "factor", + "original_max_position_embeddings": "original_max_position_embeddings", + "beta": "beta_fast", + "alpha": "beta_slow", + "apply_scale": None, + } + yarn_config = config.get("yarn") or {} + config["rope_scaling"] = { + "rope_type": "yarn", + "mscale_all_dim": 1, + } + for old_name, new_name in yarn_config_map.items(): + if old_name in yarn_config: + value = yarn_config.pop(old_name) + if new_name is not None: + config["rope_scaling"][new_name] = value + + assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" + + return config + + +def _remap_general_mistral_args(config: dict) -> dict: + # Mistral key -> HF key + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + # HF key -> (Mistral key, default value) + top_level_mapping_with_default = { + "model_type": ("model_type", "transformer"), + "hidden_act": ("activation", "silu"), + "tie_word_embeddings": ("tied_embeddings", False), + "max_seq_len": ("max_seq_len", 128_000), + "max_position_embeddings": ("max_position_embeddings", 128_000), + } + + for key, new_key in config_mapping.items(): + if key in config: + config[new_key] = config.pop(key) + + for new_key, (key, default_value) in top_level_mapping_with_default.items(): + config[new_key] = config.pop(key, default_value) + + return config + + +def _remap_mistral_quantization_args(config: dict) -> dict: + if config.get("quantization"): + quantization = config.pop("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + qscheme_act = quantization.get("qscheme_act") + assert qscheme_act in ( + "NO_SCALES", + "TENSOR", + None, + ), "Only NO_SCALES and TENSOR (default) are supported for qscheme_act" + is_dynamic = qscheme_act == "NO_SCALES" + config["quantization_config"] = { + "quant_method": "fp8", + "activation_scheme": "dynamic" if is_dynamic else "static", + } + else: + raise ValueError(f"Found unknown quantization='{quantization}' in config") + + return config + + +def _remap_mistral_audio_args(config: dict) -> dict: + whisper_args = config["multimodal"].pop("whisper_model_args") + encoder_args = whisper_args["encoder_args"] + downsample_args = whisper_args["downsample_args"] + + quant_config = config.get("quantization_config") + config = { + "model_type": "whixtral", + "architectures": ["VoxtralForConditionalGeneration"], + "text_config": PretrainedConfig.from_dict(config), + "audio_config": WhisperConfig( + num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], + window_size=encoder_args["audio_encoding_args"]["window_size"], + sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], + hop_length=encoder_args["audio_encoding_args"]["hop_length"], + downsample_factor=downsample_args["downsample_factor"], + d_model=encoder_args["dim"], + encoder_layers=encoder_args["n_layers"], + encoder_ffn_dim=encoder_args["hidden_dim"], + encoder_attention_heads=encoder_args["n_heads"], + vocab_size=encoder_args["vocab_size"], + max_source_positions=encoder_args["max_source_positions"], + is_encoder_decoder=False, # Override WhisperConfig default + ), + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_mistral_eagle_args(config: dict) -> dict: + config["architectures"] = ["MistralLarge3ForCausalLMEagle"] + return config + + +def _remap_moe_args(config: dict) -> dict: + moe_config_map = { + "route_every_n": "moe_layer_freq", + "first_k_dense_replace": "first_k_dense_replace", + "num_experts_per_tok": "num_experts_per_tok", + "num_experts": "n_routed_experts", + "expert_hidden_dim": "moe_intermediate_size", + "routed_scale": "routed_scaling_factor", + "num_shared_experts": "n_shared_experts", + "num_expert_groups": "n_group", + "num_expert_groups_per_tok": "topk_group", + } + moe_config = config.get("moe", {}) + for old_name, new_name in moe_config_map.items(): + if old_name in moe_config: + value = moe_config.pop(old_name) + config[new_name] = value + + config["topk_method"] = None + config["norm_topk_prob"] = True + config["scoring_func"] = "softmax" + + return config + + +class MistralConfigParser: + def get_hf_file_to_dict( + self, file_name: str, model: str | Path, revision: str | None = "main" + ): + file_path = Path(model) / file_name + if not file_path.is_file(): + # TODO: Add logic to download from HF in case file is not locally found + raise FileNotFoundError(f"File not found {model}, {file_name}") + + if file_path is not None and file_path.is_file(): + with open(file_path) as file: + return json.load(file) + + return None + + def _download_mistral_config_file(self, model, revision) -> dict: + config_file_name = "params.json" + config_dict = self.get_hf_file_to_dict(config_file_name, model, revision) + if config_dict is None: + raise ValueError( + f"Failed to load mistral '{config_file_name}' config for model " + f"{model}. Please check if the model is a mistral-format model " + f"and if the config file exists." + ) + assert isinstance(config_dict, dict) + return config_dict + + def parse( + self, + model: str | Path, + revision: str | None = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = self._download_mistral_config_file(model, revision) + if config_dict.get("max_position_embeddings") is None: + logger.warning( + "The params.json file is missing 'max_position_embeddings'" + " and could not get a value from the HF config." + " Defaulting to 128000" + ) + config_dict["max_position_embeddings"] = 128_000 + + config_dict, config = adapt_config_dict(config_dict, model) + + # Mistral configs may define sliding_window as list[int]. Convert it + # to int and add the layer_types list[str] to make it HF compatible + if (sliding_window := getattr(config, "sliding_window", None)) and isinstance( + sliding_window, list + ): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, config diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index c3738b04c5e6..16b15cff0f08 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -347,13 +347,29 @@ impl RouterTrait for RouterManager { } } - async fn get_model_info(&self, _req: Request) -> Response { - // TODO: Extract model from request and route to appropriate router - ( - StatusCode::NOT_IMPLEMENTED, - "Model info endpoint not yet implemented in RouterManager", - ) - .into_response() + async fn get_model_info(&self, req: Request) -> Response { + // Route to default router or first available router + let router_id = { + let default_router = self.default_router.read().unwrap(); + default_router.clone() + }; + + let router = if let Some(id) = router_id { + self.routers.get(&id).map(|r| r.clone()) + } else { + // If no default, use first available router + self.routers.iter().next().map(|r| r.value().clone()) + }; + + if let Some(router) = router { + router.get_model_info(req).await + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + "No routers available", + ) + .into_response() + } } async fn route_generate( From b92e6415cbb02a782bd5c599a32b67183f13b190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20C=C3=A1mpora?= <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 16:18:34 +0100 Subject: [PATCH 02/30] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/models/mistral_large_3.py | 3 ++- python/sglang/srt/utils/hf_transformers_utils.py | 14 +++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py index 586e6a926209..1f0c2c192666 100644 --- a/python/sglang/srt/models/mistral_large_3.py +++ b/python/sglang/srt/models/mistral_large_3.py @@ -56,7 +56,8 @@ def _iterable_remap_mistral_to_ds( name = re.sub(k, v, name) break else: - + import logging + logging.warning(f"Unrecognized weight: {name}. Skipping.") continue # Note(Andy): Unlike Llama, this implementation uses diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 00d2fc173c6f..c68c7486744b 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -197,16 +197,12 @@ def _load_mistral_large_3_for_causal_LM( parser = mistral_utils.MistralConfigParser() config_dict, _ = parser.parse(local_path) - tmp_path = os.path.join(tempfile.gettempdir(), "_tmp_config_folder") - os.makedirs(tmp_path, exist_ok=True) - - unique_path = os.path.join(tmp_path, f"mistral_large_3_{os.getpid()}") - with open(unique_path, "w") as f: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as f: json.dump(config_dict, f) - - loaded_config = AutoConfig.from_pretrained( - unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs - ) + f.flush() + loaded_config = AutoConfig.from_pretrained( + f.name, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) text_config = getattr(loaded_config, "text_config", None) if text_config is not None and isinstance(text_config, dict): text_config = AutoConfig.for_model(**text_config) From 4bfe1755c888d7c4bf58d5d6b9504dae672fb66b Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 07:20:25 -0800 Subject: [PATCH 03/30] Commit suggestions from gemini. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- python/sglang/srt/layers/quantization/fp8.py | 3 +-- python/sglang/srt/models/deepseek_v2.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 5b6c6f3788f1..2334b07b5ee1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -952,8 +952,7 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False, ) - # Precompute and register per-expert output scaling factors for FI MoE - if get_moe_runner_backend().is_flashinfer_trtllm(): + # Precompute and register per-expert output scaling factors for FI MoE # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction assert ( hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e2803e1280a8..485881b6fff5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1986,9 +1986,6 @@ def forward_absorb_core( q = torch.cat([q_nope_out, q_pe], dim=-1) k = torch.cat([k_nope, k_pe], dim=-1) - if self.llama_4_scaling: - q *= self.llama_4_scaling - attn_output = self.attn_mqa( q, k, From 94f923f291810a774d6232f6f02bd1ba8a8eac14 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 07:23:49 -0800 Subject: [PATCH 04/30] Formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../kernels/fused_moe_triton/common_utils.py | 7 ++- python/sglang/srt/layers/quantization/fp8.py | 56 +++++++++++++------ python/sglang/srt/models/mistral_large_3.py | 1 + python/sglang/srt/server_args.py | 11 +++- 4 files changed, 54 insertions(+), 21 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index dd8b09f8a79b..cc8e955c5513 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -78,7 +78,12 @@ def get_model_config( E = (config.n_routed_experts // ep_size) + ( 0 if disable_shared_experts_fusion - or architecture not in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", "MistralLarge3ForCausalLM"] + or architecture + not in [ + "DeepseekV3ForCausalLM", + "Glm4MoeForCausalLM", + "MistralLarge3ForCausalLM", + ] else 1 ) topk = config.num_experts_per_tok + ( diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 2334b07b5ee1..1a7af34f4ec5 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -915,14 +915,14 @@ def process_weights_after_loading(self, layer: Module) -> None: ) if _is_hip: self.process_weights_hip_scale_padding(layer) - + # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled if get_moe_runner_backend().is_flashinfer_trtllm(): from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a # Note: No need to swap W13 halves, they are already in the correct order: [Gate, Up] num_experts, two_n, hidden = layer.w13_weight.shape - + # 2) Reorder rows for fused gated activation (W13) w13_interleaved = [ reorder_rows_for_gated_act_gemm(layer.w13_weight[i]) @@ -935,11 +935,15 @@ def process_weights_after_loading(self, layer: Module) -> None: # 3) Shuffle weights for transposed MMA output (both W13, W2) epilogue_tile_m = 128 w13_shuffled = [ - shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) + shuffle_matrix_a( + w13_interleaved[i].view(torch.uint8), epilogue_tile_m + ) for i in range(num_experts) ] w2_shuffled = [ - shuffle_matrix_a(layer.w2_weight[i].view(torch.uint8), epilogue_tile_m) + shuffle_matrix_a( + layer.w2_weight[i].view(torch.uint8), epilogue_tile_m + ) for i in range(num_experts) ] @@ -955,15 +959,20 @@ def process_weights_after_loading(self, layer: Module) -> None: # Precompute and register per-expert output scaling factors for FI MoE # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction assert ( - hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None + hasattr(layer, "w13_input_scale") + and layer.w13_input_scale is not None + ) + assert ( + hasattr(layer, "w2_input_scale") + and layer.w2_input_scale is not None ) - assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None assert ( hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None ) assert ( - hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None + hasattr(layer, "w2_weight_scale") + and layer.w2_weight_scale is not None ) input_scale = layer.w13_input_scale.to(torch.float32) @@ -987,7 +996,7 @@ def process_weights_after_loading(self, layer: Module) -> None: output2_scales_scalar, requires_grad=False ) return - + return def process_weights_hip_int4(self, layer: Module): @@ -1071,9 +1080,7 @@ def create_moe_runner( ): from sglang.srt.layers import deep_gemm_wrapper - from sglang.srt.layers.moe.utils import ( - get_moe_a2a_backend, - ) + from sglang.srt.layers.moe.utils import get_moe_a2a_backend self.moe_runner_config = moe_runner_config moe_runner_backend = get_moe_runner_backend() @@ -1293,7 +1300,10 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer.fused_moe import trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe + from flashinfer.fused_moe import ( + trtllm_fp8_block_scale_moe, + trtllm_fp8_per_tensor_scale_moe, + ) from sglang.srt.layers.moe.topk import TopKOutputChecker from sglang.srt.layers.moe.utils import RoutingMethodType @@ -1304,9 +1314,11 @@ def apply_with_router_logits( assert ( activation == "silu" ), "Only silu is supported for flashinfer blockscale fp8 moe" - + if self.block_quant: - a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) + a_q, a_sf = per_token_group_quant_fp8( + x, self.quant_config.weight_block_size[1] + ) # NOTE: scales of hidden states have to be transposed! a_sf_t = a_sf.t().contiguous() else: @@ -1318,7 +1330,9 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - routing_method_type = getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) + routing_method_type = getattr( + layer, "routing_method_type", RoutingMethodType.DeepSeekV3 + ) with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() @@ -1350,7 +1364,9 @@ def apply_with_router_logits( local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, local_num_experts=layer.num_local_experts, routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 + routed_scaling_factor + if routed_scaling_factor is not None + else 1.0 ), tile_tokens_dim=None, routing_method_type=routing_method_type, @@ -1358,7 +1374,9 @@ def apply_with_router_logits( ) else: routing_bias_cast = ( - None if correction_bias is None else correction_bias.to(torch.bfloat16) + None + if correction_bias is None + else correction_bias.to(torch.bfloat16) ) return trtllm_fp8_per_tensor_scale_moe( @@ -1378,7 +1396,9 @@ def apply_with_router_logits( local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, local_num_experts=layer.num_local_experts, routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 + routed_scaling_factor + if routed_scaling_factor is not None + else 1.0 ), use_routing_scales_on_input=False, routing_method_type=routing_method_type, diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py index 1f0c2c192666..fd60ef61f7c0 100644 --- a/python/sglang/srt/models/mistral_large_3.py +++ b/python/sglang/srt/models/mistral_large_3.py @@ -57,6 +57,7 @@ def _iterable_remap_mistral_to_ds( break else: import logging + logging.warning(f"Unrecognized weight: {name}. Skipping.") continue diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e465a1543b61..ebe31c5b2214 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -937,10 +937,17 @@ def _handle_model_specific_adjustments(self): hf_config = self.get_hf_config() model_arch = hf_config.architectures[0] - if model_arch in ["MistralLarge3ForCausalLM", "PixtralForConditionalGeneration"]: + if model_arch in [ + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", + ]: self.dtype = "bfloat16" - if model_arch in ["DeepseekV3ForCausalLM", "MistralLarge3ForCausalLM", "PixtralForConditionalGeneration"]: + if model_arch in [ + "DeepseekV3ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", + ]: if is_deepseek_nsa(hf_config): if ( self.attention_backend is None From 4fe81ec0ad3f090ef89c04718ec43c09949aad9b Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 07:37:14 -0800 Subject: [PATCH 05/30] Add back routed scaling factor for nvfp4. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 5f8b27f2b385..90a782a60502 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1174,7 +1174,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): intermediate_size=self.intermediate_size_per_partition, local_expert_offset=self.moe_ep_rank * self.num_local_experts, local_num_experts=self.num_local_experts, - routed_scaling_factor=None, + routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, tile_tokens_dim=None, routing_method_type=routing_method_type, do_finalize=True, From efc8ffad8ab5de2805b7a51d1f869ea69d932528 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 16:12:39 -0800 Subject: [PATCH 06/30] Add back codepath for deepseek_v32 as it was. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../sglang/srt/utils/hf_transformers_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index c68c7486744b..bf0cce90e462 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -246,18 +246,21 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - if "mistral-large-3" in model: + if "mistral-large-3" in model.lower(): config = _load_mistral_large_3_for_causal_LM( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) - elif "deepseek_v32" in model: - config = _load_deepseek_v32_model( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs - ) else: - config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs - ) + try: + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + except ValueError as e: + if not "deepseek_v32" in str(e): + raise e + config = _load_deepseek_v32_model( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) if ( config.architectures is not None From 896aa66e2b639337cae50842e9e125165fa47230 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 16:29:47 -0800 Subject: [PATCH 07/30] Use str(model).lower(). Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- python/sglang/srt/utils/hf_transformers_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index bf0cce90e462..62638e4a8094 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -246,7 +246,7 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - if "mistral-large-3" in model.lower(): + if "mistral-large-3" in str(model).lower(): config = _load_mistral_large_3_for_causal_LM( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) From f1fa66ffe8bde3816c341f4959a04385a0ead3a3 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Mon, 1 Dec 2025 18:05:05 -0800 Subject: [PATCH 08/30] Adapt router manager formatting. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- sgl-router/src/routers/router_manager.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 9a4cea5201f7..65dff42be315 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -380,11 +380,7 @@ impl RouterTrait for RouterManager { if let Some(router) = router { router.get_model_info(req).await } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No routers available", - ) - .into_response() + (StatusCode::SERVICE_UNAVAILABLE, "No routers available").into_response() } } From 1a3b88d4deca88515aca0d46b37f472852359350 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 01:42:39 -0800 Subject: [PATCH 09/30] Adding fp8 loader for attn. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors.py | 11 +- .../schemes/compressed_tensors_w8a8_fp8.py | 187 ++++++++++++------ .../srt/layers/quantization/fp8_utils.py | 107 ++++++++++ 3 files changed, 238 insertions(+), 67 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 0375e25998c9..230a9c7d1ebe 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -306,7 +306,9 @@ def _is_dynamic_token_w8a8( # Only symmetric weight quantization supported. return is_8_bits and is_token and weight_quant.symmetric and is_dynamic - def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: + def _is_fp8_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: # Confirm weights and activations quantized. if weight_quant is None or input_quant is None: return False @@ -318,15 +320,16 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: ) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = weight_quant.strategy in [ + is_tensor_or_channel_or_block_weight = weight_quant.strategy in [ QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK, ] if not ( is_floating_point and is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight + and is_tensor_or_channel_or_block_weight ): return False @@ -406,7 +409,7 @@ def _get_scheme_from_parts( ) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( - strategy=weight_quant.strategy, + weight_quant=weight_quant, is_static_input_scheme=( input_quant and not input_quant.dynamic ), diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ef136a782972..8ac22bdc46e4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,13 +4,14 @@ from typing import Callable, List, Optional import torch -from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization import QuantizationStrategy, QuantizationArgs from torch.nn import Parameter from sglang.srt.layers.parameter import ( ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, + BlockQuantScaleParameter, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, @@ -20,6 +21,8 @@ apply_fp8_linear, apply_fp8_ptpc_linear, normalize_e4m3fn_to_e4m3fnuz, + validate_fp8_block_shape, + apply_fp8_block_linear, ) from sglang.srt.layers.quantization.utils import requantize_with_max_scale from sglang.srt.utils import get_bool_env_var, is_hip @@ -32,21 +35,108 @@ from aiter.ops.shuffle import shuffle_weight -class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): +strategy_to_parameter_type = { + QuantizationStrategy.BLOCK: BlockQuantScaleParameter, + QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, + QuantizationStrategy.TENSOR: PerTensorScaleParameter, +} + - def __init__(self, strategy: str, is_static_input_scheme: bool): - self.strategy = strategy +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): + self.weight_quant = weight_quant + self.strategy = self.weight_quant.strategy self.is_static_input_scheme = is_static_input_scheme + self.weight_block_size = self.weight_quant.block_structure @classmethod def get_min_capability(cls) -> int: # lovelace and up return 89 + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.weight_block_size = None + layer.orig_dtype = params_dtype + + if self.strategy == QuantizationStrategy.BLOCK: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + # Validate block quantization shapes + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, + ) + + # WEIGHT + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + elif self.strategy == QuantizationStrategy.TENSOR: + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + elif self.strategy == QuantizationStrategy.BLOCK: + assert layer.weight_block_size is not None + block_n, block_k = layer.weight_block_size[0], layer.weight_block_size[1] + output_size_per_partition = sum(output_partition_sizes) + weight_scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + def process_weights_after_loading(self, layer) -> None: - # If per tensor, when we have a fused module (e.g. QKV) with per - # tensor scales (thus N scales being passed to the kernel), - # requantize so we can always run per tensor if self.strategy == QuantizationStrategy.TENSOR: max_w_scale, weight = requantize_with_max_scale( weight=layer.weight, @@ -65,7 +155,6 @@ def process_weights_after_loading(self, layer) -> None: layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight @@ -93,75 +182,47 @@ def process_weights_after_loading(self, layer) -> None: # required by torch.compile to be torch.nn.Parameter layer.weight_scale = Parameter(weight_scale, requires_grad=False) + elif self.strategy == QuantizationStrategy.BLOCK: + assert self.is_static_input_scheme is False + + if is_fp8_fnuz(): + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.weight, weight_scale=layer.weight_scale + ) + + input_scale = None + else: raise ValueError(f"Unknown quantization strategy {self.strategy}") + # required by torch.compile to be torch.nn.Parameter + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + if input_scale is not None: + layer.input_scale = Parameter(input_scale.data, requires_grad=False) + # INPUT SCALE if self.is_static_input_scheme and hasattr(layer, "input_scale"): layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None - def create_weights( - self, - layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, - **kwargs, - ): - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - - # WEIGHT - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - # WEIGHT SCALE - # TODO: update create_xxx_parameter functions to return - # the newly added parameters - if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader, - ) - else: - assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - - # min requirement for fp8 kernels - weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) - - # INPUT SCALE - if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - input_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("input_scale", input_scale) - def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: + if self.weight_block_size is not None: + return apply_fp8_block_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + weight_group_shape=self.weight_block_size, + bias=bias, + ) + if _use_aiter and self.strategy == QuantizationStrategy.CHANNEL: return apply_fp8_ptpc_linear( input=x, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index f8a98e98fad5..c6c820977b1e 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -971,3 +971,110 @@ def apply_fp8_ptpc_linear( if bias is not None: output = output + bias return output.view(*output_shape) + +def validate_fp8_block_shape( + layer: torch.nn.Module, + input_size: int, + output_size: int, + input_size_per_partition: int, + output_partition_sizes: list[int], + block_size: list[int], +) -> None: + """Validate block quantization shapes for tensor parallelism.""" + from sglang.srt.distributed import get_tensor_model_parallel_world_size + + tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) + block_n, block_k = block_size[0], block_size[1] + + # Required by row parallel + if ( + tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} " + f"is not divisible by weight quantization block_k = {block_k}." + ) + + # Required by column parallel or enabling merged weights + is_tp_split = tp_size > 1 and output_size // sum(output_partition_sizes) == tp_size + is_merged_gemm = len(output_partition_sizes) > 1 + if is_tp_split or is_merged_gemm: + sizes_to_check = output_partition_sizes + if not is_tp_split and is_merged_gemm: + # In case of merged matrices, we allow the last + # matrix to not be a multiple of block size + sizes_to_check = output_partition_sizes[:-1] + for output_partition_size in sizes_to_check: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + +def apply_fp8_block_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: torch.Tensor | None = None, + weight_group_shape: list[int] | None = None, + bias: torch.Tensor | None = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: + # input: torch.Tensor, + # scale: Optional[torch.Tensor] = None, + # num_token_padding: Optional[int] = None, + # use_per_token_if_dynamic: bool = False, + assert input_scale is None + + # GroupShape(1, self.weight_block_size[0]) + q_input, input_scale = scaled_fp8_quant( + input_2d, + input_scale=None, + num_token_padding=None, + use_per_token_if_dynamic=False, + ) + + # Fall back to vllm cutlass w8a8 fp8 kernel + output = ops.cutlass_scaled_mm( + q_input, + weight, + input_scale, + weight_scale, + weight_group_shape, + input_2d.dtype, + ) + else: + # Run triton w8a8 fp8 kernel + + assert input_scale is None + # assert self.input_quant_op is not None + # q_input, input_scale = self.input_quant_op(input_2d) + # return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( + # q_input, + # weight, + # input_scale, + # weight_scale, + # weight_group_shape, + # input_2d.dtype, + # ) + + output = triton_w8a8_block_fp8_linear( + input_2d, + weight, + weight_group_shape, + weight_scale, + input_scale, + bias, + ) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) \ No newline at end of file From 0353f52f96c37c673c18223e05e62fd124f55d88 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 02:26:10 -0800 Subject: [PATCH 10/30] Use preexisting dispatch fn. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../schemes/compressed_tensors_w8a8_fp8.py | 6 +- .../srt/layers/quantization/fp8_utils.py | 65 ------------------- 2 files changed, 3 insertions(+), 68 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 8ac22bdc46e4..57451d2291df 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -22,7 +22,7 @@ apply_fp8_ptpc_linear, normalize_e4m3fn_to_e4m3fnuz, validate_fp8_block_shape, - apply_fp8_block_linear, + dispatch_w8a8_block_fp8_linear, ) from sglang.srt.layers.quantization.utils import requantize_with_max_scale from sglang.srt.utils import get_bool_env_var, is_hip @@ -214,12 +214,12 @@ def apply_weights( bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.weight_block_size is not None: - return apply_fp8_block_linear( + return dispatch_w8a8_block_fp8_linear( input=x, weight=layer.weight, + block_size=self.weight_block_size, weight_scale=layer.weight_scale, input_scale=layer.input_scale, - weight_group_shape=self.weight_block_size, bias=bias, ) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c6c820977b1e..c181d859b3e3 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1013,68 +1013,3 @@ def validate_fp8_block_shape( f"{output_partition_size} is not divisible by " f"weight quantization block_n = {block_n}." ) - -def apply_fp8_block_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_scale: torch.Tensor | None = None, - weight_group_shape: list[int] | None = None, - bias: torch.Tensor | None = None, -) -> torch.Tensor: - assert input_scale is None - # View input as 2D matrix for fp8 methods - input_2d = input.view(-1, input.shape[-1]) - output_shape = [*input.shape[:-1], weight.shape[0]] - - if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: - # input: torch.Tensor, - # scale: Optional[torch.Tensor] = None, - # num_token_padding: Optional[int] = None, - # use_per_token_if_dynamic: bool = False, - assert input_scale is None - - # GroupShape(1, self.weight_block_size[0]) - q_input, input_scale = scaled_fp8_quant( - input_2d, - input_scale=None, - num_token_padding=None, - use_per_token_if_dynamic=False, - ) - - # Fall back to vllm cutlass w8a8 fp8 kernel - output = ops.cutlass_scaled_mm( - q_input, - weight, - input_scale, - weight_scale, - weight_group_shape, - input_2d.dtype, - ) - else: - # Run triton w8a8 fp8 kernel - - assert input_scale is None - # assert self.input_quant_op is not None - # q_input, input_scale = self.input_quant_op(input_2d) - # return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( - # q_input, - # weight, - # input_scale, - # weight_scale, - # weight_group_shape, - # input_2d.dtype, - # ) - - output = triton_w8a8_block_fp8_linear( - input_2d, - weight, - weight_group_shape, - weight_scale, - input_scale, - bias, - ) - - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) \ No newline at end of file From 69186b3ce8ebdf6db6119a78ff5e0e51206b3dd9 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 2 Dec 2025 02:28:17 -0800 Subject: [PATCH 11/30] fp8 block moe --- .../compressed_tensors_moe.py | 103 +++++++++++++++++- .../srt/layers/quantization/fp8_utils.py | 7 ++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5ab499f67074..024fb10410ac 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -10,7 +10,10 @@ import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy +from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor +from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase @@ -18,7 +21,11 @@ WNA16_SUPPORTED_BITS, ) from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + expert_weight_is_col_major, + normalize_e4m3fn_to_e4m3fnuz, + requant_weight_ue8m0_inplace, +) from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales from sglang.srt.layers.quantization.utils import ( @@ -81,8 +88,8 @@ def get_moe_method( weight_quant = quant_config.target_scheme_map["Linear"].get("weights") input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") - if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): @@ -102,7 +109,28 @@ def __init__(self, quant_config: CompressedTensorsConfig): "input_activations" ) + per_tensor = ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR + ) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not (per_tensor or per_channel): + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + self.weight_block_size = self.weight_quant.block_structure + assert self.weight_quant.dynamic is not None + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization." + ) def create_weights( self, @@ -117,6 +145,32 @@ def create_weights( params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tp_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( @@ -169,6 +223,26 @@ def create_weights( requires_grad=False, ) weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value else: raise ValueError( f"Unsupported weight quantization strategy: {self.weight_quant.strategy}" @@ -291,6 +365,31 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No ) torch.cuda.empty_cache() + if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 and self.block_quant: + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale.data, + block_sz, + ) + + # Ensure column-major TMA alignment expected by DeepGEMM. + if expert_weight_is_col_major(layer.w13_weight_scale): + layer.w13_weight_scale = get_mn_major_tma_aligned_tensor( + layer.w13_weight_scale + ) + if expert_weight_is_col_major(layer.w2_weight_scale): + layer.w2_weight_scale = get_mn_major_tma_aligned_tensor( + layer.w2_weight_scale + ) + def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c181d859b3e3..c6de3f4cb0ac 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -972,6 +972,7 @@ def apply_fp8_ptpc_linear( output = output + bias return output.view(*output_shape) + def validate_fp8_block_shape( layer: torch.nn.Module, input_size: int, @@ -1013,3 +1014,9 @@ def validate_fp8_block_shape( f"{output_partition_size} is not divisible by " f"weight quantization block_n = {block_n}." ) + + +def expert_weight_is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m From 7aac21691bf6239bd9459ed3b1bb94807b387c35 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 03:09:54 -0800 Subject: [PATCH 12/30] Fixing issues. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors.py | 1 - .../compressed_tensors/compressed_tensors_moe.py | 4 ++-- python/sglang/srt/models/deepseek_v2.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 230a9c7d1ebe..4e6a799e4cd5 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -91,7 +91,6 @@ def __init__( self.sparsity_ignore_list = sparsity_ignore_list self.config = config self.packed_modules_mapping = packed_modules_mapping or {} - # FP8 config for linear layers, compressed tensor currently does not support block fp8, this is used for ktransformers self.linear_fp8_config = linear_fp8_config def get_linear_method(self) -> CompressedTensorsLinearMethod: diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 024fb10410ac..cb1a07878c5a 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -12,7 +12,7 @@ from compressed_tensors.quantization import QuantizationStrategy from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor -from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo @@ -148,7 +148,7 @@ def create_weights( if self.block_quant: assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size - tp_size = get_tp_world_size() + tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( self.weight_block_size[0], self.weight_block_size[1], diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 485881b6fff5..7915f2be5239 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -101,6 +101,7 @@ from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, @@ -718,6 +719,19 @@ def __init__( ): # For compressed-tensors ptpc model, don't need to check the weight_block_size pass + elif ( + type(self.shared_experts.gate_up_proj.quant_method) == CompressedTensorsLinearMethod + ): + gate_up_proj_config = self.shared_experts.gate_up_proj.quant_method.quantization_config.config + down_proj_config = self.shared_experts.down_proj.quant_method.quantization_config.config + + gate_up_proj_group_size = gate_up_proj_config["config_groups"]["FP8_BLOCK"]["input_activations"]["group_size"] + down_proj_group_size = down_proj_config["config_groups"]["FP8_BLOCK"]["input_activations"]["group_size"] + + assert gate_up_proj_group_size == down_proj_group_size + self.shared_experts_weight_block_size = gate_up_proj_group_size + + # config {'config_groups': {'FP8_BLOCK': {'format': 'float-quantized', 'input_activations': {'actorder': None, 'block_structure': None, 'dynamic': True, 'group_size': 128, 'num_bits': 8, 'observer': None, 'observer_kwargs': {}, 'strategy': 'group', 'symmetric': True, 'type': 'float'}, 'output_activations': None, 'targets': ['Linear'], 'weights': {'actorder': None, 'block_structure': [128, 128], 'dynamic': False, 'group_size': None, 'num_bits': 8, 'observer': 'static_minmax', 'observer_kwargs': {}, 'strategy': 'block', 'symmetric': True, 'type': 'float'}}}, 'format': 'float-quantized', 'global_compression_ratio': None, 'ignore': ['model.embed_tokens', 're:patch_merger.*', 're:vision_encoder.*', 're:vision_language_adapter.*', 're:.*kv_a_proj_with_mqa$', 're:.*q_a_proj$', 're:.*gate$', 'lm_head'], 'kv_cache_scheme': None, 'quant_method': 'compressed-tensors', 'quantization_status': 'compressed', 'sparsity_config': {}, 'transform_config': {}, 'version': '0.12.3.dev29+g73c2cf9.d20251119', 'packed_modules_mapping': {}} else: assert ( self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size From ee68d595a46e83239355a3bf3181b13202dcd883 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 03:16:50 -0800 Subject: [PATCH 13/30] Fix. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- python/sglang/srt/layers/linear.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index f3500540dd62..6c66ab2091b7 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -31,6 +31,7 @@ RowvLLMParameter, _ColumnvLLMParameter, ) +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.utils import pad_or_narrow_weight from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs @@ -735,7 +736,10 @@ def weight_loader_v2( assert loaded_shard_id < len(self.output_sizes) if isinstance(param, BlockQuantScaleParameter): - weight_block_size = self.quant_method.quant_config.weight_block_size + if type(self.quant_method) == CompressedTensorsLinearMethod: + weight_block_size = self.quant_method.quantization_config.config["config_groups"]["FP8_BLOCK"]["input_activations"]["group_size"] + else: + weight_block_size = self.quant_method.quant_config.weight_block_size raw_block_n, _ = weight_block_size[0], weight_block_size[1] block_n = 1 if getattr(param, "format_ue8m0", False) else raw_block_n shard_offset = ( From bc4ed51b39b2ff4a8b374a95296d901445e4de57 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 03:21:42 -0800 Subject: [PATCH 14/30] Possible fix. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- python/sglang/srt/layers/linear.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 6c66ab2091b7..bc0d57faf25d 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -737,7 +737,7 @@ def weight_loader_v2( if isinstance(param, BlockQuantScaleParameter): if type(self.quant_method) == CompressedTensorsLinearMethod: - weight_block_size = self.quant_method.quantization_config.config["config_groups"]["FP8_BLOCK"]["input_activations"]["group_size"] + weight_block_size = self.quant_method.quantization_config.config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] else: weight_block_size = self.quant_method.quant_config.weight_block_size raw_block_n, _ = weight_block_size[0], weight_block_size[1] diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 7915f2be5239..5c524681acdc 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -725,11 +725,11 @@ def __init__( gate_up_proj_config = self.shared_experts.gate_up_proj.quant_method.quantization_config.config down_proj_config = self.shared_experts.down_proj.quant_method.quantization_config.config - gate_up_proj_group_size = gate_up_proj_config["config_groups"]["FP8_BLOCK"]["input_activations"]["group_size"] - down_proj_group_size = down_proj_config["config_groups"]["FP8_BLOCK"]["input_activations"]["group_size"] + gate_up_proj_block_structure = gate_up_proj_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] + down_proj_block_structure = down_proj_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] - assert gate_up_proj_group_size == down_proj_group_size - self.shared_experts_weight_block_size = gate_up_proj_group_size + assert gate_up_proj_block_structure == down_proj_block_structure + self.shared_experts_weight_block_size = gate_up_proj_block_structure # config {'config_groups': {'FP8_BLOCK': {'format': 'float-quantized', 'input_activations': {'actorder': None, 'block_structure': None, 'dynamic': True, 'group_size': 128, 'num_bits': 8, 'observer': None, 'observer_kwargs': {}, 'strategy': 'group', 'symmetric': True, 'type': 'float'}, 'output_activations': None, 'targets': ['Linear'], 'weights': {'actorder': None, 'block_structure': [128, 128], 'dynamic': False, 'group_size': None, 'num_bits': 8, 'observer': 'static_minmax', 'observer_kwargs': {}, 'strategy': 'block', 'symmetric': True, 'type': 'float'}}}, 'format': 'float-quantized', 'global_compression_ratio': None, 'ignore': ['model.embed_tokens', 're:patch_merger.*', 're:vision_encoder.*', 're:vision_language_adapter.*', 're:.*kv_a_proj_with_mqa$', 're:.*q_a_proj$', 're:.*gate$', 'lm_head'], 'kv_cache_scheme': None, 'quant_method': 'compressed-tensors', 'quantization_status': 'compressed', 'sparsity_config': {}, 'transform_config': {}, 'version': '0.12.3.dev29+g73c2cf9.d20251119', 'packed_modules_mapping': {}} else: From 74cda590e242a19996adfabe6c13c256f8469d15 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:16:58 -0800 Subject: [PATCH 15/30] Test. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- python/sglang/srt/layers/linear.py | 5 +---- .../compressed_tensors/compressed_tensors.py | 5 +++++ python/sglang/srt/layers/quantization/fp8.py | 12 +++++----- python/sglang/srt/models/deepseek_v2.py | 5 +++++ python/sglang/srt/utils/mistral_utils.py | 22 +++++++++++++++++++ 5 files changed, 38 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index bc0d57faf25d..a9c38000ff52 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -736,10 +736,7 @@ def weight_loader_v2( assert loaded_shard_id < len(self.output_sizes) if isinstance(param, BlockQuantScaleParameter): - if type(self.quant_method) == CompressedTensorsLinearMethod: - weight_block_size = self.quant_method.quantization_config.config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] - else: - weight_block_size = self.quant_method.quant_config.weight_block_size + weight_block_size = self.quant_method.quant_config.weight_block_size raw_block_n, _ = weight_block_size[0], weight_block_size[1] block_n = 1 if getattr(param, "format_ue8m0", False) else raw_block_n shard_offset = ( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 4e6a799e4cd5..af2b7c16f06b 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -35,6 +35,7 @@ CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, ) +from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, @@ -93,6 +94,7 @@ def __init__( self.packed_modules_mapping = packed_modules_mapping or {} self.linear_fp8_config = linear_fp8_config + def get_linear_method(self) -> CompressedTensorsLinearMethod: return CompressedTensorsLinearMethod(self) @@ -158,6 +160,8 @@ def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig: if "linear_fp8_config" in config: from sglang.srt.layers.quantization.fp8 import Fp8Config + print("linear_fp8_config", config["linear_fp8_config"]) + fp8_cfg = config["linear_fp8_config"] # Check if it's fp8 format based on quant_method field is_fp8 = fp8_cfg.get("quant_method") == "fp8" @@ -610,6 +614,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config + self.quant_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 1a7af34f4ec5..ab97877a32a5 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -131,10 +131,10 @@ def __init__( raise ValueError( f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." ) - if activation_scheme != "dynamic": - raise ValueError( - f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." - ) + # if activation_scheme != "dynamic": + # raise ValueError( + # f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." + # ) self.weight_block_size = weight_block_size @classmethod @@ -285,9 +285,7 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if self.block_quant: - if hasattr(self.quant_config, "activation_scheme"): - assert self.quant_config.activation_scheme == "dynamic" - elif hasattr(self.quant_config, "linear_activation_scheme"): + if hasattr(self.quant_config, "linear_activation_scheme"): assert self.quant_config.linear_activation_scheme == "dynamic" scale = BlockQuantScaleParameter( data=torch.empty( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5c524681acdc..4807b5e9daa9 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -3581,6 +3581,11 @@ def post_load_weights(self, is_nextn=False, weight_names=None): weight_block_size = getattr( selected_quant_config, "weight_block_size", None ) + + print("weight_block_size", weight_block_size) + print("selected_quant_config", selected_quant_config) + print("self.quant_config", self.quant_config) + if weight_block_size is not None: assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") if _is_fp8_fnuz: diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index d4ce44aded5a..a205872ecf44 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -104,6 +104,28 @@ def _remap_mistral_vision_args(config: dict) -> dict: } if quant_config: config["quantization_config"] = quant_config + + if quant_config["config_groups"]["FP8_BLOCK"] is not None: + config["quantization_config"]["linear_fp8_config"] = { + "quant_method": "fp8", + "activation_scheme": "dynamic" if quant_config["config_groups"]["FP8_BLOCK"]["input_activations"]["dynamic"] else "static", + "ignored_layers": quant_config["ignore"], + "weight_block_size": quant_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"], + } + + # linear_fp8_config = None + # if "linear_fp8_config" in config: + # from sglang.srt.layers.quantization.fp8 import Fp8Config + + # fp8_cfg = config["linear_fp8_config"] + # # Check if it's fp8 format based on quant_method field + # is_fp8 = fp8_cfg.get("quant_method") == "fp8" + # linear_fp8_config = Fp8Config( + # is_checkpoint_fp8_serialized=is_fp8, + # activation_scheme=fp8_cfg.get("activation_scheme", "dynamic"), + # ignored_layers=fp8_cfg.get("ignored_layers"), + # weight_block_size=fp8_cfg.get("weight_block_size"), + # ) return config From 7f0c02177a7e033e027692de73193ca65d0c9ced Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:20:02 -0800 Subject: [PATCH 16/30] Revert "Test." This reverts commit 74cda590e242a19996adfabe6c13c256f8469d15. --- python/sglang/srt/layers/linear.py | 5 ++++- .../compressed_tensors/compressed_tensors.py | 5 ----- python/sglang/srt/layers/quantization/fp8.py | 12 +++++----- python/sglang/srt/models/deepseek_v2.py | 5 ----- python/sglang/srt/utils/mistral_utils.py | 22 ------------------- 5 files changed, 11 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index a9c38000ff52..bc0d57faf25d 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -736,7 +736,10 @@ def weight_loader_v2( assert loaded_shard_id < len(self.output_sizes) if isinstance(param, BlockQuantScaleParameter): - weight_block_size = self.quant_method.quant_config.weight_block_size + if type(self.quant_method) == CompressedTensorsLinearMethod: + weight_block_size = self.quant_method.quantization_config.config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] + else: + weight_block_size = self.quant_method.quant_config.weight_block_size raw_block_n, _ = weight_block_size[0], weight_block_size[1] block_n = 1 if getattr(param, "format_ue8m0", False) else raw_block_n shard_offset = ( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index af2b7c16f06b..4e6a799e4cd5 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -35,7 +35,6 @@ CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, ) -from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, @@ -94,7 +93,6 @@ def __init__( self.packed_modules_mapping = packed_modules_mapping or {} self.linear_fp8_config = linear_fp8_config - def get_linear_method(self) -> CompressedTensorsLinearMethod: return CompressedTensorsLinearMethod(self) @@ -160,8 +158,6 @@ def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig: if "linear_fp8_config" in config: from sglang.srt.layers.quantization.fp8 import Fp8Config - print("linear_fp8_config", config["linear_fp8_config"]) - fp8_cfg = config["linear_fp8_config"] # Check if it's fp8 format based on quant_method field is_fp8 = fp8_cfg.get("quant_method") == "fp8" @@ -614,7 +610,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config - self.quant_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index ab97877a32a5..1a7af34f4ec5 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -131,10 +131,10 @@ def __init__( raise ValueError( f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." ) - # if activation_scheme != "dynamic": - # raise ValueError( - # f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." - # ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." + ) self.weight_block_size = weight_block_size @classmethod @@ -285,7 +285,9 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if self.block_quant: - if hasattr(self.quant_config, "linear_activation_scheme"): + if hasattr(self.quant_config, "activation_scheme"): + assert self.quant_config.activation_scheme == "dynamic" + elif hasattr(self.quant_config, "linear_activation_scheme"): assert self.quant_config.linear_activation_scheme == "dynamic" scale = BlockQuantScaleParameter( data=torch.empty( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 4807b5e9daa9..5c524681acdc 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -3581,11 +3581,6 @@ def post_load_weights(self, is_nextn=False, weight_names=None): weight_block_size = getattr( selected_quant_config, "weight_block_size", None ) - - print("weight_block_size", weight_block_size) - print("selected_quant_config", selected_quant_config) - print("self.quant_config", self.quant_config) - if weight_block_size is not None: assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") if _is_fp8_fnuz: diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index a205872ecf44..d4ce44aded5a 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -104,28 +104,6 @@ def _remap_mistral_vision_args(config: dict) -> dict: } if quant_config: config["quantization_config"] = quant_config - - if quant_config["config_groups"]["FP8_BLOCK"] is not None: - config["quantization_config"]["linear_fp8_config"] = { - "quant_method": "fp8", - "activation_scheme": "dynamic" if quant_config["config_groups"]["FP8_BLOCK"]["input_activations"]["dynamic"] else "static", - "ignored_layers": quant_config["ignore"], - "weight_block_size": quant_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"], - } - - # linear_fp8_config = None - # if "linear_fp8_config" in config: - # from sglang.srt.layers.quantization.fp8 import Fp8Config - - # fp8_cfg = config["linear_fp8_config"] - # # Check if it's fp8 format based on quant_method field - # is_fp8 = fp8_cfg.get("quant_method") == "fp8" - # linear_fp8_config = Fp8Config( - # is_checkpoint_fp8_serialized=is_fp8, - # activation_scheme=fp8_cfg.get("activation_scheme", "dynamic"), - # ignored_layers=fp8_cfg.get("ignored_layers"), - # weight_block_size=fp8_cfg.get("weight_block_size"), - # ) return config From 15d620158368e516895241c5fd3fb77a5eb0c5ab Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:47:12 -0800 Subject: [PATCH 17/30] fix quant config --- python/sglang/srt/layers/linear.py | 6 +----- .../compressed_tensors/compressed_tensors.py | 10 ++++++++++ python/sglang/srt/models/deepseek_v2.py | 6 ++++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index bc0d57faf25d..f3500540dd62 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -31,7 +31,6 @@ RowvLLMParameter, _ColumnvLLMParameter, ) -from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.utils import pad_or_narrow_weight from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs @@ -736,10 +735,7 @@ def weight_loader_v2( assert loaded_shard_id < len(self.output_sizes) if isinstance(param, BlockQuantScaleParameter): - if type(self.quant_method) == CompressedTensorsLinearMethod: - weight_block_size = self.quant_method.quantization_config.config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] - else: - weight_block_size = self.quant_method.quant_config.weight_block_size + weight_block_size = self.quant_method.quant_config.weight_block_size raw_block_n, _ = weight_block_size[0], weight_block_size[1] block_n = 1 if getattr(param, "format_ue8m0", False) else raw_block_n shard_offset = ( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 4e6a799e4cd5..31f47b88bc2f 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -141,6 +141,15 @@ def get_quant_method( return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix) return None + @property + def weight_block_size(self) -> Optional[List[int]]: + """Get the weight block size from the quantization config.""" + if "Linear" in self.target_scheme_map: + weights_config = self.target_scheme_map["Linear"].get("weights") + if weights_config and hasattr(weights_config, "block_structure"): + return weights_config.block_structure + return None + @classmethod def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig: ignore: List[str] = cast(List[str], config.get("ignore", [])) @@ -610,6 +619,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config + self.quant_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5c524681acdc..ce4412c522d5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -724,7 +724,7 @@ def __init__( ): gate_up_proj_config = self.shared_experts.gate_up_proj.quant_method.quantization_config.config down_proj_config = self.shared_experts.down_proj.quant_method.quantization_config.config - + gate_up_proj_block_structure = gate_up_proj_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] down_proj_block_structure = down_proj_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] @@ -3576,8 +3576,10 @@ def post_load_weights(self, is_nextn=False, weight_names=None): ): # For mixed quantization (experts int4, linear fp8), use linear_fp8_config selected_quant_config = getattr( - self.quant_config, "linear_fp8_config", self.quant_config + self.quant_config, "linear_fp8_config", None ) + if selected_quant_config is None: + selected_quant_config = self.quant_config weight_block_size = getattr( selected_quant_config, "weight_block_size", None ) From f4ad69c000896a3bb9b314a9e4124e2cf44edff1 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 05:22:36 -0800 Subject: [PATCH 18/30] Add fix. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 57451d2291df..486d27325065 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -122,6 +122,8 @@ def create_weights( output_dim=0, weight_loader=weight_loader, ) + # The weight_scale_inv name is intentional for deepseekv3 + layer.register_parameter("weight_scale_inv", weight_scale) # min requirement for fp8 kernels weight_scale[:] = torch.finfo(torch.float32).min From 0a48bd0d346e6a955a76cfcb704f0c41f892bec0 Mon Sep 17 00:00:00 2001 From: Daniel Campora <961215+dcampora@users.noreply.github.com> Date: Tue, 2 Dec 2025 06:02:25 -0800 Subject: [PATCH 19/30] Fix. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- .../schemes/compressed_tensors_w8a8_fp8.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 486d27325065..1d7d82ab7249 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -103,11 +103,15 @@ def create_weights( output_dim=0, weight_loader=weight_loader, ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) elif self.strategy == QuantizationStrategy.TENSOR: weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, - ) + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) elif self.strategy == QuantizationStrategy.BLOCK: assert layer.weight_block_size is not None block_n, block_k = layer.weight_block_size[0], layer.weight_block_size[1] @@ -122,13 +126,10 @@ def create_weights( output_dim=0, weight_loader=weight_loader, ) - # The weight_scale_inv name is intentional for deepseekv3 + weight_scale.format_ue8m0 = False + weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale_inv", weight_scale) - # min requirement for fp8 kernels - weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) - # INPUT SCALE if self.is_static_input_scheme: input_scale = PerTensorScaleParameter( From 90f589a904ba88a7baf395b99cb60d5f4abb0f76 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 2 Dec 2025 08:33:04 -0800 Subject: [PATCH 20/30] fix to run --- .../compressed_tensors_moe.py | 8 ++--- .../schemes/compressed_tensors_w8a8_fp8.py | 34 +++++++++---------- .../srt/layers/quantization/fp8_utils.py | 2 +- python/sglang/srt/models/mistral_large_3.py | 3 ++ 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index cb1a07878c5a..9007b1897199 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -370,13 +370,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( - layer.w13_weight.data, - layer.w13_weight_scale.data, + layer.w13_weight, + layer.w13_weight_scale, block_sz, ) requant_weight_ue8m0_inplace( - layer.w2_weight.data, - layer.w2_weight_scale.data, + layer.w2_weight, + layer.w2_weight_scale, block_sz, ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 1d7d82ab7249..05b73eecd021 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,17 +1,17 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Callable import torch -from compressed_tensors.quantization import QuantizationStrategy, QuantizationArgs +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, - BlockQuantScaleParameter, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, @@ -20,9 +20,9 @@ from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, apply_fp8_ptpc_linear, + dispatch_w8a8_block_fp8_linear, normalize_e4m3fn_to_e4m3fnuz, validate_fp8_block_shape, - dispatch_w8a8_block_fp8_linear, ) from sglang.srt.layers.quantization.utils import requantize_with_max_scale from sglang.srt.utils import get_bool_env_var, is_hip @@ -48,6 +48,8 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool) self.strategy = self.weight_quant.strategy self.is_static_input_scheme = is_static_input_scheme self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear() @classmethod def get_min_capability(cls) -> int: @@ -187,23 +189,21 @@ def process_weights_after_loading(self, layer) -> None: elif self.strategy == QuantizationStrategy.BLOCK: assert self.is_static_input_scheme is False - + weight = layer.weight + weight_scale_inv = layer.weight_scale_inv + if is_fp8_fnuz(): - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.weight, weight_scale=layer.weight_scale + weight, weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale_inv ) - - input_scale = None + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale_inv = Parameter( + weight_scale_inv.data, requires_grad=False + ) else: raise ValueError(f"Unknown quantization strategy {self.strategy}") - # required by torch.compile to be torch.nn.Parameter - layer.weight = Parameter(weight.data, requires_grad=False) - layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) - if input_scale is not None: - layer.input_scale = Parameter(input_scale.data, requires_grad=False) - # INPUT SCALE if self.is_static_input_scheme and hasattr(layer, "input_scale"): layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) @@ -217,11 +217,11 @@ def apply_weights( bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.weight_block_size is not None: - return dispatch_w8a8_block_fp8_linear( + return self.w8a8_block_fp8_linear( input=x, weight=layer.weight, block_size=self.weight_block_size, - weight_scale=layer.weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=layer.input_scale, bias=bias, ) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c6de3f4cb0ac..e44b7ce22131 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -523,7 +523,7 @@ def requant_weight_ue8m0( weight_scale_inv: torch.Tensor, weight_block_size: List[int], ): - assert weight_block_size == [128, 128] + assert tuple(weight_block_size) == (128, 128) *_, n, k = weight.shape diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py index fd60ef61f7c0..719d47915407 100644 --- a/python/sglang/srt/models/mistral_large_3.py +++ b/python/sglang/srt/models/mistral_large_3.py @@ -72,6 +72,9 @@ def _iterable_remap_mistral_to_ds( elif name.endswith(".qscale_weight"): name = re.sub(r"\.qscale_weight$", ".weight_scale", name) + if name.endswith(".weight_scale") and ".experts." not in name: + name = re.sub(r"\.weight_scale$", ".weight_scale_inv", name) + yield name, loaded_weight From c7be4b383a4c4368f2b42a4366b12e77376412b2 Mon Sep 17 00:00:00 2001 From: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Date: Tue, 2 Dec 2025 09:18:59 -0800 Subject: [PATCH 21/30] tuple simplification Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors_moe.py | 5 ++--- python/sglang/srt/layers/quantization/fp8_utils.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 9007b1897199..cecb6d60ff1a 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -368,16 +368,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 and self.block_quant: assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. - block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( layer.w13_weight, layer.w13_weight_scale, - block_sz, + layer.weight_block_size, ) requant_weight_ue8m0_inplace( layer.w2_weight, layer.w2_weight_scale, - block_sz, + layer.weight_block_size, ) # Ensure column-major TMA alignment expected by DeepGEMM. diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index e44b7ce22131..c6de3f4cb0ac 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -523,7 +523,7 @@ def requant_weight_ue8m0( weight_scale_inv: torch.Tensor, weight_block_size: List[int], ): - assert tuple(weight_block_size) == (128, 128) + assert weight_block_size == [128, 128] *_, n, k = weight.shape From 1ef9948ff1ab90115bdd4b865d7a738552f29c9c Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 2 Dec 2025 19:39:47 -0800 Subject: [PATCH 22/30] fix lint --- .../schemes/compressed_tensors_w8a8_fp8.py | 4 ++-- python/sglang/srt/models/deepseek_v2.py | 14 -------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 05b73eecd021..2effad0360f9 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,7 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Callable, Optional import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy @@ -214,7 +214,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: torch.Tensor | None = None, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.weight_block_size is not None: return self.w8a8_block_fp8_linear( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ce4412c522d5..21fa4e5f2c90 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -101,7 +101,6 @@ from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, @@ -719,19 +718,6 @@ def __init__( ): # For compressed-tensors ptpc model, don't need to check the weight_block_size pass - elif ( - type(self.shared_experts.gate_up_proj.quant_method) == CompressedTensorsLinearMethod - ): - gate_up_proj_config = self.shared_experts.gate_up_proj.quant_method.quantization_config.config - down_proj_config = self.shared_experts.down_proj.quant_method.quantization_config.config - - gate_up_proj_block_structure = gate_up_proj_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] - down_proj_block_structure = down_proj_config["config_groups"]["FP8_BLOCK"]["weights"]["block_structure"] - - assert gate_up_proj_block_structure == down_proj_block_structure - self.shared_experts_weight_block_size = gate_up_proj_block_structure - - # config {'config_groups': {'FP8_BLOCK': {'format': 'float-quantized', 'input_activations': {'actorder': None, 'block_structure': None, 'dynamic': True, 'group_size': 128, 'num_bits': 8, 'observer': None, 'observer_kwargs': {}, 'strategy': 'group', 'symmetric': True, 'type': 'float'}, 'output_activations': None, 'targets': ['Linear'], 'weights': {'actorder': None, 'block_structure': [128, 128], 'dynamic': False, 'group_size': None, 'num_bits': 8, 'observer': 'static_minmax', 'observer_kwargs': {}, 'strategy': 'block', 'symmetric': True, 'type': 'float'}}}, 'format': 'float-quantized', 'global_compression_ratio': None, 'ignore': ['model.embed_tokens', 're:patch_merger.*', 're:vision_encoder.*', 're:vision_language_adapter.*', 're:.*kv_a_proj_with_mqa$', 're:.*q_a_proj$', 're:.*gate$', 'lm_head'], 'kv_cache_scheme': None, 'quant_method': 'compressed-tensors', 'quantization_status': 'compressed', 'sparsity_config': {}, 'transform_config': {}, 'version': '0.12.3.dev29+g73c2cf9.d20251119', 'packed_modules_mapping': {}} else: assert ( self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size From b9d5f644350d597cbfc4b72aa6d41977d812d57a Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:32:55 -0800 Subject: [PATCH 23/30] fix dsv3 accuracy --- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 10 +++------- python/sglang/srt/layers/quantization/fp8.py | 4 +--- python/sglang/srt/models/deepseek_v2.py | 6 ++---- python/sglang/srt/utils/mistral_utils.py | 2 +- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index e9b43674fb86..32c4d2fd7306 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1084,7 +1084,9 @@ def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor): hs_fp4_bytes, hs_sf_bytes = fp4_quantize( hidden_states, self.w13_input_scale_quant, - is_sf_swizzled_layout=False, + 16, # sf_vec_size + False, # use_ue8m0 + False, # is_sf_swizzled_layout ) hs_fp4 = hs_fp4_bytes.reshape( @@ -1125,12 +1127,6 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): else topk_config.correction_bias.to(hidden_states.dtype) ) - router_logits = ( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ) - with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 1a7af34f4ec5..349ddd15f3ba 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -563,7 +563,6 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn - tp_size = get_tensor_model_parallel_world_size() if self.block_quant: block_n, block_k = ( @@ -913,6 +912,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False ) + if _is_hip: self.process_weights_hip_scale_padding(layer) @@ -997,8 +997,6 @@ def process_weights_after_loading(self, layer: Module) -> None: ) return - return - def process_weights_hip_int4(self, layer: Module): # TODO: _use_aiter: add after triton kernel added # INT4-FP8 (INT4 MoE Weight, FP8 Compute) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 21fa4e5f2c90..b3b2fa0d7e1b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -649,10 +649,8 @@ def __init__( layer_id=self.layer_id, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, - routing_method_type=( - RoutingMethodType.Renormalize - if config.norm_topk_prob - else RoutingMethodType.DeepSeekV3 + routing_method_type=getattr( + config, "routing_method_type", RoutingMethodType.DeepSeekV3 ), prefix=add_prefix("experts", prefix), ) diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index d4ce44aded5a..34372fb68d4e 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -235,7 +235,7 @@ def _remap_moe_args(config: dict) -> dict: config[new_name] = value config["topk_method"] = None - config["norm_topk_prob"] = True + config["routing_method_type"] = 1 # RoutingMethodType.Renormalize config["scoring_func"] = "softmax" return config From dc500a0f14b7e386277d49285c8948b877e69057 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 2 Dec 2025 23:03:03 -0800 Subject: [PATCH 24/30] fix ckpt loading --- .../sglang/srt/utils/hf_transformers_utils.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 62638e4a8094..4cacc88f49d7 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -500,12 +500,20 @@ def get_processor( ): # pop 'revision' from kwargs if present. revision = kwargs.pop("revision", tokenizer_revision) - config = AutoConfig.from_pretrained( - tokenizer_name, - trust_remote_code=trust_remote_code, - revision=revision, - **kwargs, - ) + if "mistral-large-3" in str(tokenizer_name).lower(): + config = _load_mistral_large_3_for_causal_LM( + tokenizer_name, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + else: + config = AutoConfig.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) if _is_deepseek_ocr_model(config): # Temporary hack for load deepseek-ocr config.model_type = "deepseek-ocr" From 77319c4568d4d65a58d0f8d52856211898811f55 Mon Sep 17 00:00:00 2001 From: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> Date: Wed, 3 Dec 2025 04:59:41 -0800 Subject: [PATCH 25/30] fix no module named deep_gemm error Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com> --- .../compressed_tensors/compressed_tensors_moe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index cecb6d60ff1a..0b86186734b0 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -10,7 +10,11 @@ import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy -from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor + +from python.sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM + +if ENABLE_JIT_DEEPGEMM: + from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers import deep_gemm_wrapper From 623b33aa7f723313a703ffb962efe5194cf78bf3 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 3 Dec 2025 05:11:04 -0800 Subject: [PATCH 26/30] fix triton moe --- .../compressed_tensors_moe.py | 48 +++++-------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0b86186734b0..8d0d475ac1f2 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -11,13 +11,7 @@ from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy -from python.sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM - -if ENABLE_JIT_DEEPGEMM: - from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor - from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase @@ -25,11 +19,7 @@ WNA16_SUPPORTED_BITS, ) from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant -from sglang.srt.layers.quantization.fp8_utils import ( - expert_weight_is_col_major, - normalize_e4m3fn_to_e4m3fnuz, - requant_weight_ue8m0_inplace, -) +from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales from sglang.srt.layers.quantization.utils import ( @@ -369,30 +359,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> No ) torch.cuda.empty_cache() - if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 and self.block_quant: - assert layer.weight_block_size is not None - # Re-quantise the expert weights so their scales are UE8M0. - requant_weight_ue8m0_inplace( - layer.w13_weight, - layer.w13_weight_scale, - layer.weight_block_size, - ) - requant_weight_ue8m0_inplace( - layer.w2_weight, - layer.w2_weight_scale, - layer.weight_block_size, - ) - - # Ensure column-major TMA alignment expected by DeepGEMM. - if expert_weight_is_col_major(layer.w13_weight_scale): - layer.w13_weight_scale = get_mn_major_tma_aligned_tensor( - layer.w13_weight_scale - ) - if expert_weight_is_col_major(layer.w2_weight_scale): - layer.w2_weight_scale = get_mn_major_tma_aligned_tensor( - layer.w2_weight_scale - ) - def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): @@ -445,6 +411,18 @@ def apply( a2_scale=layer.w2_input_scale, ) return StandardCombineInput(hidden_states=output) + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.weight_block_size, + ) + return self.runner.run(dispatch_output, quant_info) else: quant_info = TritonMoeQuantInfo( w13_weight=layer.w13_weight, From 73abf6bf80ef5f21384486c526113bca273ab820 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 3 Dec 2025 21:55:29 +0800 Subject: [PATCH 27/30] clean up (#2) --- python/sglang/srt/layers/quantization/fp8.py | 197 ++++-------------- .../srt/layers/quantization/fp8_utils.py | 6 - .../srt/layers/quantization/modelopt_quant.py | 30 +-- python/sglang/srt/models/deepseek_v2.py | 7 +- 4 files changed, 49 insertions(+), 191 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 349ddd15f3ba..f86fcf413b10 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -915,86 +915,6 @@ def process_weights_after_loading(self, layer: Module) -> None: if _is_hip: self.process_weights_hip_scale_padding(layer) - - # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled - if get_moe_runner_backend().is_flashinfer_trtllm(): - from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a - - # Note: No need to swap W13 halves, they are already in the correct order: [Gate, Up] - num_experts, two_n, hidden = layer.w13_weight.shape - - # 2) Reorder rows for fused gated activation (W13) - w13_interleaved = [ - reorder_rows_for_gated_act_gemm(layer.w13_weight[i]) - for i in range(num_experts) - ] - w13_interleaved = torch.stack(w13_interleaved).reshape( - num_experts, two_n, hidden - ) - - # 3) Shuffle weights for transposed MMA output (both W13, W2) - epilogue_tile_m = 128 - w13_shuffled = [ - shuffle_matrix_a( - w13_interleaved[i].view(torch.uint8), epilogue_tile_m - ) - for i in range(num_experts) - ] - w2_shuffled = [ - shuffle_matrix_a( - layer.w2_weight[i].view(torch.uint8), epilogue_tile_m - ) - for i in range(num_experts) - ] - - layer.w13_weight = Parameter( - torch.stack(w13_shuffled).view(torch.float8_e4m3fn), - requires_grad=False, - ) - layer.w2_weight = Parameter( - torch.stack(w2_shuffled).view(torch.float8_e4m3fn), - requires_grad=False, - ) - - # Precompute and register per-expert output scaling factors for FI MoE - # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction - assert ( - hasattr(layer, "w13_input_scale") - and layer.w13_input_scale is not None - ) - assert ( - hasattr(layer, "w2_input_scale") - and layer.w2_input_scale is not None - ) - assert ( - hasattr(layer, "w13_weight_scale") - and layer.w13_weight_scale is not None - ) - assert ( - hasattr(layer, "w2_weight_scale") - and layer.w2_weight_scale is not None - ) - - input_scale = layer.w13_input_scale.to(torch.float32) - activation_scale = layer.w2_input_scale.to(torch.float32) - w13_weight_scale = layer.w13_weight_scale.to(torch.float32) - w2_weight_scale = layer.w2_weight_scale.to(torch.float32) - - output1_scales_scalar = ( - w13_weight_scale * input_scale * (1.0 / activation_scale) - ) - output1_scales_gate_scalar = w13_weight_scale * input_scale - output2_scales_scalar = activation_scale * w2_weight_scale - - layer.output1_scales_scalar = Parameter( - output1_scales_scalar, requires_grad=False - ) - layer.output1_scales_gate_scalar = Parameter( - output1_scales_gate_scalar, requires_grad=False - ) - layer.output2_scales_scalar = Parameter( - output2_scales_scalar, requires_grad=False - ) return def process_weights_hip_int4(self, layer: Module): @@ -1298,10 +1218,7 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer.fused_moe import ( - trtllm_fp8_block_scale_moe, - trtllm_fp8_per_tensor_scale_moe, - ) + from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from sglang.srt.layers.moe.topk import TopKOutputChecker from sglang.srt.layers.moe.utils import RoutingMethodType @@ -1312,15 +1229,9 @@ def apply_with_router_logits( assert ( activation == "silu" ), "Only silu is supported for flashinfer blockscale fp8 moe" - - if self.block_quant: - a_q, a_sf = per_token_group_quant_fp8( - x, self.quant_config.weight_block_size[1] - ) - # NOTE: scales of hidden states have to be transposed! - a_sf_t = a_sf.t().contiguous() - else: - a_q, _ = scaled_fp8_quant(x, layer.w13_input_scale) + a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) + # NOTE: scales of hidden states have to be transposed! + a_sf_t = a_sf.t().contiguous() correction_bias = ( None @@ -1328,79 +1239,43 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - routing_method_type = getattr( - layer, "routing_method_type", RoutingMethodType.DeepSeekV3 - ) + routing_method_type = getattr(layer, "routing_method_type") with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): - if self.block_quant: - # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. - # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 - # so we put the whole function under the ``use_symmetric_memory`` context manager. - # If the bug is fixed, we can only put the output tensor allocation under the context manager. - return trtllm_fp8_block_scale_moe( - routing_logits=( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor - if routed_scaling_factor is not None - else 1.0 - ), - tile_tokens_dim=None, - routing_method_type=routing_method_type, - use_shuffled_weight=False, - ) - else: - routing_bias_cast = ( - None - if correction_bias is None - else correction_bias.to(torch.bfloat16) - ) - - return trtllm_fp8_per_tensor_scale_moe( - routing_logits=router_logits.to(torch.bfloat16), - routing_bias=routing_bias_cast, - hidden_states=a_q, - gemm1_weights=layer.w13_weight, - output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - gemm2_weights=layer.w2_weight, - output2_scales_scalar=layer.output2_scales_scalar, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor - if routed_scaling_factor is not None - else 1.0 - ), - use_routing_scales_on_input=False, - routing_method_type=routing_method_type, - ) + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. + return trtllm_fp8_block_scale_moe( + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + tile_tokens_dim=None, + routing_method_type=routing_method_type, + use_shuffled_weight=False, + ) def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index c6de3f4cb0ac..70914c9d3fd2 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1014,9 +1014,3 @@ def validate_fp8_block_shape( f"{output_partition_size} is not divisible by " f"weight quantization block_n = {block_n}." ) - - -def expert_weight_is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index ba31c12c6fdb..6b106379eec0 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1565,27 +1565,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_weight_scale_2 = layer.w13_weight_scale_2[:] layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) - def _slice_scale(w): - assert w.shape == (layer.num_experts,) - assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts - return w[ - layer.moe_ep_rank - * layer.num_local_experts : (layer.moe_ep_rank + 1) - * layer.num_local_experts - ] - # Calculate input scales based on strategy if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe: - w13_input_scale = ( - layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) - ) - w2_input_scale = ( - layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) - ) - - if layer.moe_ep_size > 1: - w13_input_scale = _slice_scale(w13_input_scale) - w2_input_scale = _slice_scale(w2_input_scale) + w13_input_scale = layer.w13_input_scale.max().to(torch.float32) + w2_input_scale = layer.w2_input_scale.max().to(torch.float32) elif self.enable_flashinfer_cutedsl_moe: # All-expert-one-input-scale is mathematically different from default per-expert-input-scale # Thus we allow users to switch the flag to do thorough testing @@ -1602,6 +1585,15 @@ def _slice_scale(w): w2_input_scale = layer.w2_input_scale + def _slice_scale(w): + assert w.shape == (layer.num_experts,) + assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts + return w[ + layer.moe_ep_rank + * layer.num_local_experts : (layer.moe_ep_rank + 1) + * layer.num_local_experts + ] + w13_input_scale = _slice_scale(w13_input_scale) w2_input_scale = _slice_scale(w2_input_scale) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b3b2fa0d7e1b..9dc9fed76524 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -649,9 +649,7 @@ def __init__( layer_id=self.layer_id, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, - routing_method_type=getattr( - config, "routing_method_type", RoutingMethodType.DeepSeekV3 - ), + routing_method_type=RoutingMethodType.DeepSeekV3, prefix=add_prefix("experts", prefix), ) @@ -3349,7 +3347,6 @@ def forward( class DeepseekV2ForCausalLM(nn.Module): # for quark model load packed_modules_mapping = {} - model_cls = DeepseekV2Model def __init__( self, @@ -3376,7 +3373,7 @@ def __init__( self.quant_config = quant_config self.determine_num_fused_shared_experts() self.use_nsa = is_deepseek_nsa(config) - self.model = self.model_cls( + self.model = DeepseekV2Model( config, quant_config, prefix=add_prefix("model", prefix) ) if self.pp_group.is_last_rank: From d7a7b6cc0ad1405328b627a84f1f7f44b0c84a81 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 1 Dec 2025 00:02:58 -0800 Subject: [PATCH 28/30] add init llama4 scale --- .../layers/attention/trtllm_mla_backend.py | 11 ++++ python/sglang/srt/models/deepseek_v2.py | 51 ++++++++++++++++--- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 0c8f832ed26b..7afa0a09ed7c 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -802,6 +802,7 @@ def forward_decode( k_rope: Optional[torch.Tensor] = None, cos_sin_cache: Optional[torch.Tensor] = None, is_neox: Optional[bool] = False, + llama_4_scaling: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Run forward for decode using TRTLLM MLA kernel.""" merge_query = q_rope is not None @@ -843,6 +844,11 @@ def forward_decode( # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function query = q.view(-1, layer.tp_q_head_num, layer.head_dim) + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + query = query.to(self.q_data_type) * llama_4_scaling + query = query.to(self.data_type) + # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1 if query.dim() == 3: query = query.unsqueeze(1) @@ -903,6 +909,7 @@ def forward_extend( k_rope: Optional[torch.Tensor] = None, cos_sin_cache: Optional[torch.Tensor] = None, is_neox: Optional[bool] = False, + llama_4_scaling: Optional[torch.Tensor] = None, ) -> torch.Tensor: if ( @@ -955,6 +962,10 @@ def forward_extend( q = q.view(-1, layer.tp_q_head_num, layer.head_dim) + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + if ( forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend(include_v2=True) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9dc9fed76524..1331fe916724 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1158,6 +1158,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 +def _get_llama_4_scaling( + original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor +) -> torch.Tensor: + scaling = 1 + scaling_beta * torch.log( + 1 + torch.floor(positions / original_max_position_embeddings) + ) + # Broadcast over num_heads and head_dim + return scaling[..., None, None] + + class DeepseekV2AttentionMLA(nn.Module): def __init__( @@ -1212,12 +1222,6 @@ def __init__( if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" - self.llama_4_scaling = ( - config.to_dict()["llama_4_scaling"]["beta"] - if "llama_4_scaling" in config - else None - ) - # For tensor parallel attention if self.q_lora_rank is not None: self.fused_qkv_a_proj_with_mqa = ReplicatedLinear( @@ -1470,12 +1474,14 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, + llama_4_scaling: Optional[torch.Tensor], ): s = self.forward_prepare( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, zero_allocator=zero_allocator, + llama_4_scaling=llama_4_scaling, ) return self.forward_core(s) @@ -1485,6 +1491,7 @@ def forward_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, + llama_4_scaling: Optional[torch.Tensor], ): if self.attn_mha.kv_b_proj is None: self.attn_mha.kv_b_proj = self.kv_b_proj @@ -1525,7 +1532,11 @@ def forward_prepare( elif attn_forward_method == AttnForwardMethod.MLA: if not self.is_mla_preprocess_enabled: inner_state = self.forward_absorb_prepare( - positions, hidden_states, forward_batch, zero_allocator + positions, + hidden_states, + forward_batch, + zero_allocator, + llama_4_scaling, ) else: # TODO(iforgetmyname): to be separated as a standalone func @@ -1756,6 +1767,7 @@ def forward_absorb_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, + llama_4_scaling: Optional[torch.Tensor], ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode @@ -1933,6 +1945,7 @@ def forward_absorb_prepare( zero_allocator, positions, topk_indices, + llama_4_scaling, ) def forward_absorb_core( @@ -1945,6 +1958,7 @@ def forward_absorb_core( zero_allocator, positions, topk_indices, + llama_4_scaling, ): if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS: extra_args = {} @@ -1952,6 +1966,7 @@ def forward_absorb_core( extra_args = { "cos_sin_cache": self.rotary_emb.cos_sin_cache, "is_neox": self.rotary_emb.is_neox_style, + "llama_4_scaling": llama_4_scaling, } attn_output = self.attn_mqa( @@ -1982,6 +1997,10 @@ def forward_absorb_core( q = torch.cat([q_nope_out, q_pe], dim=-1) k = torch.cat([k_nope, k_pe], dim=-1) + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + attn_output = self.attn_mqa( q, k, @@ -2972,6 +2991,7 @@ def forward( residual: Optional[torch.Tensor], zero_allocator: BumpAllocator, gemm_output_zero_allocator: BumpAllocator = None, + llama_4_scaling: Optional[torch.Tensor] = None, ) -> torch.Tensor: quant_format = ( "mxfp4" @@ -3012,6 +3032,7 @@ def forward( hidden_states=hidden_states, forward_batch=forward_batch, zero_allocator=zero_allocator, + llama_4_scaling=llama_4_scaling, ) hidden_states, residual = self.layer_communicator.prepare_mlp( @@ -3230,6 +3251,9 @@ def __init__( ) self.layers_to_capture = [] + # llama_4_scaling: for supporting Mistral-Large-3 model + self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None) + def get_input_embeddings(self) -> torch.Tensor: return self.embed_tokens @@ -3278,6 +3302,18 @@ def forward( if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + # llama_4_scaling: for supporting Mistral-Large-3 model + # Compute llama 4 scaling once per forward pass if enabled + llama_4_scaling: Optional[torch.Tensor] = None + if self.llama_4_scaling_config is not None: + llama_4_scaling = _get_llama_4_scaling( + original_max_position_embeddings=self.llama_4_scaling_config[ + "original_max_position_embeddings" + ], + scaling_beta=self.llama_4_scaling_config["beta"], + positions=positions, + ) + normal_start_layer = self.start_layer normal_end_layer = self.end_layer if forward_batch.can_run_tbo: @@ -3301,6 +3337,7 @@ def forward( residual, zero_allocator, gemm_output_zero_allocator, + llama_4_scaling, ) if normal_end_layer != self.end_layer: From 1b1c84bfa32d3548945c5eb4d40bfb11f978adff Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 3 Dec 2025 06:25:48 -0800 Subject: [PATCH 29/30] clean up eagle --- python/sglang/srt/configs/model_config.py | 1 - .../srt/models/mistral_large_3_eagle.py | 104 ------------------ python/sglang/srt/utils/mistral_utils.py | 8 -- 3 files changed, 113 deletions(-) delete mode 100644 python/sglang/srt/models/mistral_large_3_eagle.py diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index dd04b9295c2d..87fc66b63c6f 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -338,7 +338,6 @@ def _derive_model_shapes(self): or "DotsVLMForCausalLM" in self.hf_config.architectures or "MistralLarge3ForCausalLM" in self.hf_config.architectures or "PixtralForConditionalGeneration" in self.hf_config.architectures - or "MistralLarge3ForCausalLMEagle" in self.hf_config.architectures ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py deleted file mode 100644 index e85a629e76f0..000000000000 --- a/python/sglang/srt/models/mistral_large_3_eagle.py +++ /dev/null @@ -1,104 +0,0 @@ -from typing import Optional - -import torch -from torch import nn -from transformers import PretrainedConfig - -from python.sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp -from sglang.srt.distributed import get_pp_group -from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import RowParallelLinear -from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV2Model -from sglang.srt.models.mistral_large_3 import MistralLarge3ForCausalLM -from sglang.srt.utils import add_prefix - - -class MistralLarge3Model(DeepseekV2Model): - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - nn.Module.__init__(self) - - self.config = config - self.vocab_size = config.vocab_size - assert get_pp_group().world_size == 1 - self.pp_group = get_pp_group() - self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=add_prefix("embed_tokens", prefix), - ) - - self.layers = nn.ModuleList( - [ - DeepseekV2DecoderLayer( - config=config, - prefix=add_prefix(prefix, f"layers.{i}"), - quant_config=quant_config, - layer_id=i, - ) - for i in range(self.config.num_hidden_layers) - ] - ) - self.start_layer = 0 - self.end_layer = self.config.num_hidden_layers - - self.fc = RowParallelLinear( - self.config.hidden_size * 2, - self.config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=add_prefix(prefix, "fc"), - input_is_parallel=False, - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.layers_to_capture = [] - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: torch.Tensor = None, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> torch.Tensor: - if input_embeds is None: - input_embeds = self.embed_tokens(input_ids) - input_embeds, _ = self.fc( - torch.cat((input_embeds, forward_batch.spec_info.hidden_states), dim=-1) - ) - output = super().forward( - input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors - ) - assert isinstance(output, torch.Tensor) - return output - - -class MistralLarge3ForCausalLMEagle(MistralLarge3ForCausalLM): - remapping = MistralLarge3ForCausalLM.remapping | { - r"eagle_linear\.weight": r"model.fc.weight", - r"eagle_linear\.qscale_act": r"model.fc.input_scale", - r"eagle_linear\.qscale_weight": r"model.fc.weight_scale", - } - - def __init__( - self, - *, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - config.quant_config = quant_config - self.model_cls = MistralLarge3Model - super().__init__(config=config, quant_config=quant_config, prefix=prefix) - - -EntryClass = [MistralLarge3ForCausalLMEagle] diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index 34372fb68d4e..e23abb53c5fd 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -70,7 +70,6 @@ def adapt_config_dict( "encoder_args" ) ) - is_eagle = "eagle" in model assert not (is_vision and is_audio), "Vision and audio are mutually exclusive" @@ -78,8 +77,6 @@ def adapt_config_dict( config_dict = _remap_mistral_vision_args(config_dict) if is_audio: config_dict = _remap_mistral_audio_args(config_dict) - if is_eagle: - config_dict = _remap_mistral_eagle_args(config_dict) config = PretrainedConfig.from_dict(config_dict) @@ -211,11 +208,6 @@ def _remap_mistral_audio_args(config: dict) -> dict: return config -def _remap_mistral_eagle_args(config: dict) -> dict: - config["architectures"] = ["MistralLarge3ForCausalLMEagle"] - return config - - def _remap_moe_args(config: dict) -> dict: moe_config_map = { "route_every_n": "moe_layer_freq", From 5d52732a4e71bd670ebd0754f2bd45f2711dc7b0 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 3 Dec 2025 22:44:55 -0800 Subject: [PATCH 30/30] fix Llama4 scaling None --- python/sglang/srt/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 1331fe916724..06e63d58229e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1474,7 +1474,7 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, - llama_4_scaling: Optional[torch.Tensor], + llama_4_scaling: Optional[torch.Tensor] = None, ): s = self.forward_prepare( positions=positions, @@ -1491,7 +1491,7 @@ def forward_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, - llama_4_scaling: Optional[torch.Tensor], + llama_4_scaling: Optional[torch.Tensor] = None, ): if self.attn_mha.kv_b_proj is None: self.attn_mha.kv_b_proj = self.kv_b_proj @@ -1767,7 +1767,7 @@ def forward_absorb_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, - llama_4_scaling: Optional[torch.Tensor], + llama_4_scaling: Optional[torch.Tensor] = None, ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode