diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py
index 40f8697f1e4e..cc8e955c5513 100644
--- a/benchmark/kernels/fused_moe_triton/common_utils.py
+++ b/benchmark/kernels/fused_moe_triton/common_utils.py
@@ -73,11 +73,17 @@ def get_model_config(
         "DeepseekV2ForCausalLM",
         "DeepseekV3ForCausalLM",
         "Glm4MoeForCausalLM",
+        "MistralLarge3ForCausalLM",
     ]:
         E = (config.n_routed_experts // ep_size) + (
             0
             if disable_shared_experts_fusion
-            or architecture not in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]
+            or architecture
+            not in [
+                "DeepseekV3ForCausalLM",
+                "Glm4MoeForCausalLM",
+                "MistralLarge3ForCausalLM",
+            ]
             else 1
         )
         topk = config.num_experts_per_tok + (
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 26dfbe5eb1d5..87fc66b63c6f 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -58,6 +58,8 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
             "DeepseekV3ForCausalLM",
             "DeepseekV32ForCausalLM",
             "DeepseekV3ForCausalLMNextN",
+            "MistralLarge3ForCausalLM",
+            "PixtralForConditionalGeneration",
         ]
         and getattr(config, "index_topk", None) is not None
     )
@@ -334,6 +336,8 @@ def _derive_model_shapes(self):
             or "LongcatFlashForCausalLM" in self.hf_config.architectures
             or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
             or "DotsVLMForCausalLM" in self.hf_config.architectures
+            or "MistralLarge3ForCausalLM" in self.hf_config.architectures
+            or "PixtralForConditionalGeneration" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -939,6 +943,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
     "NemotronH_Nano_VL_V2",
+    "PixtralForConditionalGeneration",
     "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 0c8f832ed26b..7afa0a09ed7c 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -802,6 +802,7 @@ def forward_decode(
         k_rope: Optional[torch.Tensor] = None,
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
+        llama_4_scaling: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MLA kernel."""
         merge_query = q_rope is not None
@@ -843,6 +844,11 @@ def forward_decode(
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)

+        # Apply llama 4 scaling if provided
+        if llama_4_scaling is not None:
+            query = query.to(self.q_data_type) * llama_4_scaling
+            query = query.to(self.data_type)
+
         # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
         if query.dim() == 3:
             query = query.unsqueeze(1)
@@ -903,6 +909,7 @@ def forward_extend(
         k_rope: Optional[torch.Tensor] = None,
         cos_sin_cache: Optional[torch.Tensor] = None,
         is_neox: Optional[bool] = False,
+        llama_4_scaling: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         if (
@@ -955,6 +962,10 @@ def forward_extend(

         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)

+        # Apply llama 4 scaling if provided
+        if llama_4_scaling is not None:
+            q *= llama_4_scaling
+
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend(include_v2=True)
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index 0375e25998c9..31f47b88bc2f 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -91,7 +91,6 @@ def __init__(
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
         self.packed_modules_mapping = packed_modules_mapping or {}
-        # FP8 config for linear layers, compressed tensor currently does not support block fp8, this is used for ktransformers
         self.linear_fp8_config = linear_fp8_config

     def get_linear_method(self) -> CompressedTensorsLinearMethod:
@@ -142,6 +141,15 @@ def get_quant_method(
                 return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
         return None

+    @property
+    def weight_block_size(self) -> Optional[List[int]]:
+        """Get the weight block size from the quantization config."""
+        if "Linear" in self.target_scheme_map:
+            weights_config = self.target_scheme_map["Linear"].get("weights")
+            if weights_config and hasattr(weights_config, "block_structure"):
+                return weights_config.block_structure
+        return None
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
         ignore: List[str] = cast(List[str], config.get("ignore", []))
@@ -306,7 +314,9 @@ def _is_dynamic_token_w8a8(
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

-    def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8(
+        self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
+    ) -> bool:
         # Confirm weights and activations quantized.
if weight_quant is None or input_quant is None: return False @@ -318,15 +328,16 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: ) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = weight_quant.strategy in [ + is_tensor_or_channel_or_block_weight = weight_quant.strategy in [ QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK, ] if not ( is_floating_point and is_symmetric_weight and is_static_weight - and is_per_tensor_or_channel_weight + and is_tensor_or_channel_or_block_weight ): return False @@ -406,7 +417,7 @@ def _get_scheme_from_parts( ) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( - strategy=weight_quant.strategy, + weight_quant=weight_quant, is_static_input_scheme=( input_quant and not input_quant.dynamic ), @@ -608,6 +619,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config + self.quant_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5ab499f67074..8d0d475ac1f2 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -11,6 +11,7 @@ from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase @@ -81,8 +82,8 @@ def get_moe_method( weight_quant = quant_config.target_scheme_map["Linear"].get("weights") input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") - if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): @@ -102,7 +103,28 @@ def __init__(self, quant_config: CompressedTensorsConfig): "input_activations" ) + per_tensor = ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR + ) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not (per_tensor or per_channel): + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + self.weight_block_size = self.weight_quant.block_structure + assert self.weight_quant.dynamic is not None + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization." 
+ ) def create_weights( self, @@ -117,6 +139,32 @@ def create_weights( params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( @@ -169,6 +217,26 @@ def create_weights( requires_grad=False, ) weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value else: raise ValueError( f"Unsupported weight quantization strategy: {self.weight_quant.strategy}" @@ -343,6 +411,18 @@ def apply( a2_scale=layer.w2_input_scale, ) return StandardCombineInput(hidden_states=output) + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.weight_block_size, + ) + return self.runner.run(dispatch_output, quant_info) else: quant_info = TritonMoeQuantInfo( w13_weight=layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index ef136a782972..2effad0360f9 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,13 +1,14 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional +from typing import Callable, Optional import torch -from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, 
ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, @@ -19,7 +20,9 @@ from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, apply_fp8_ptpc_linear, + dispatch_w8a8_block_fp8_linear, normalize_e4m3fn_to_e4m3fnuz, + validate_fp8_block_shape, ) from sglang.srt.layers.quantization.utils import requantize_with_max_scale from sglang.srt.utils import get_bool_env_var, is_hip @@ -32,21 +35,113 @@ from aiter.ops.shuffle import shuffle_weight -class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): +strategy_to_parameter_type = { + QuantizationStrategy.BLOCK: BlockQuantScaleParameter, + QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter, + QuantizationStrategy.TENSOR: PerTensorScaleParameter, +} + - def __init__(self, strategy: str, is_static_input_scheme: bool): - self.strategy = strategy +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): + self.weight_quant = weight_quant + self.strategy = self.weight_quant.strategy self.is_static_input_scheme = is_static_input_scheme + self.weight_block_size = self.weight_quant.block_structure + if self.weight_block_size is not None: + self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear() @classmethod def get_min_capability(cls) -> int: # lovelace and up return 89 + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.weight_block_size = None + layer.orig_dtype = params_dtype + + if self.strategy == QuantizationStrategy.BLOCK: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + # Validate block quantization shapes + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, + ) + + # WEIGHT + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + elif self.strategy == QuantizationStrategy.TENSOR: + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + elif self.strategy == QuantizationStrategy.BLOCK: + assert layer.weight_block_size is not None + block_n, block_k = layer.weight_block_size[0], layer.weight_block_size[1] + output_size_per_partition = sum(output_partition_sizes) + weight_scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + 
weight_scale.format_ue8m0 = False + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale_inv", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + def process_weights_after_loading(self, layer) -> None: - # If per tensor, when we have a fused module (e.g. QKV) with per - # tensor scales (thus N scales being passed to the kernel), - # requantize so we can always run per tensor if self.strategy == QuantizationStrategy.TENSOR: max_w_scale, weight = requantize_with_max_scale( weight=layer.weight, @@ -65,7 +160,6 @@ def process_weights_after_loading(self, layer) -> None: layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight @@ -93,6 +187,20 @@ def process_weights_after_loading(self, layer) -> None: # required by torch.compile to be torch.nn.Parameter layer.weight_scale = Parameter(weight_scale, requires_grad=False) + elif self.strategy == QuantizationStrategy.BLOCK: + assert self.is_static_input_scheme is False + weight = layer.weight + weight_scale_inv = layer.weight_scale_inv + + if is_fp8_fnuz(): + weight, weight_scale_inv, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale_inv + ) + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale_inv = Parameter( + weight_scale_inv.data, requires_grad=False + ) + else: raise ValueError(f"Unknown quantization strategy {self.strategy}") @@ -102,66 +210,22 @@ def process_weights_after_loading(self, layer) -> None: else: layer.input_scale = None - def create_weights( - self, - layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, - **kwargs, - ): - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes - - # WEIGHT - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - # WEIGHT SCALE - # TODO: update create_xxx_parameter functions to return - # the newly added parameters - if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader, - ) - else: - assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - - # min requirement for fp8 kernels - weight_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", weight_scale) - - # INPUT SCALE - if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - input_scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("input_scale", input_scale) - def apply_weights( 
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if self.weight_block_size is not None: + return self.w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + ) + if _use_aiter and self.strategy == QuantizationStrategy.CHANNEL: return apply_fp8_ptpc_linear( input=x, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index d858abcead86..d1a42236f689 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -786,7 +786,6 @@ def process_weights_after_loading(self, layer: Module) -> None: _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) else: # For fp8 moe run with deepgemm, the expert weights and scales need be requantized to ue8m0 - from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE from sglang.srt.model_loader.utils import ( should_deepgemm_weight_requant_ue8m0, @@ -1006,10 +1005,7 @@ def create_moe_runner( ): from sglang.srt.layers import deep_gemm_wrapper - from sglang.srt.layers.moe.utils import ( - get_moe_a2a_backend, - get_moe_runner_backend, - ) + from sglang.srt.layers.moe.utils import get_moe_a2a_backend self.moe_runner_config = moe_runner_config moe_runner_backend = get_moe_runner_backend() diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index f8a98e98fad5..70914c9d3fd2 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -971,3 +971,46 @@ def apply_fp8_ptpc_linear( if bias is not None: output = output + bias return output.view(*output_shape) + + +def validate_fp8_block_shape( + layer: torch.nn.Module, + input_size: int, + output_size: int, + input_size_per_partition: int, + output_partition_sizes: list[int], + block_size: list[int], +) -> None: + """Validate block quantization shapes for tensor parallelism.""" + from sglang.srt.distributed import get_tensor_model_parallel_world_size + + tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size()) + block_n, block_k = block_size[0], block_size[1] + + # Required by row parallel + if ( + tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} " + f"is not divisible by weight quantization block_k = {block_k}." + ) + + # Required by column parallel or enabling merged weights + is_tp_split = tp_size > 1 and output_size // sum(output_partition_sizes) == tp_size + is_merged_gemm = len(output_partition_sizes) > 1 + if is_tp_split or is_merged_gemm: + sizes_to_check = output_partition_sizes + if not is_tp_split and is_merged_gemm: + # In case of merged matrices, we allow the last + # matrix to not be a multiple of block size + sizes_to_check = output_partition_sizes[:-1] + for output_partition_size in sizes_to_check: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." 
+ ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 7d0a16003880..30d7224dc3e1 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1218,6 +1218,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 +def _get_llama_4_scaling( + original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor +) -> torch.Tensor: + scaling = 1 + scaling_beta * torch.log( + 1 + torch.floor(positions / original_max_position_embeddings) + ) + # Broadcast over num_heads and head_dim + return scaling[..., None, None] + + class DeepseekV2AttentionMLA(nn.Module): def __init__( @@ -1519,12 +1529,14 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, + llama_4_scaling: Optional[torch.Tensor] = None, ): s = self.forward_prepare( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, zero_allocator=zero_allocator, + llama_4_scaling=llama_4_scaling, ) return self.forward_core(s) @@ -1534,6 +1546,7 @@ def forward_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, + llama_4_scaling: Optional[torch.Tensor] = None, ): if self.attn_mha.kv_b_proj is None: self.attn_mha.kv_b_proj = self.kv_b_proj @@ -1573,7 +1586,7 @@ def forward_prepare( ) elif attn_forward_method == AttnForwardMethod.MLA: inner_state = self.forward_absorb_prepare( - positions, hidden_states, forward_batch, zero_allocator + positions, hidden_states, forward_batch, zero_allocator, llama_4_scaling ) elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE: inner_state = self.forward_absorb_fused_mla_rope_prepare( @@ -1797,6 +1810,7 @@ def forward_absorb_prepare( hidden_states: torch.Tensor, forward_batch: ForwardBatch, zero_allocator: BumpAllocator, + llama_4_scaling: Optional[torch.Tensor] = None, ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode @@ -1974,6 +1988,7 @@ def forward_absorb_prepare( zero_allocator, positions, topk_indices, + llama_4_scaling, ) def forward_absorb_core( @@ -1986,6 +2001,7 @@ def forward_absorb_core( zero_allocator, positions, topk_indices, + llama_4_scaling, ): if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS: extra_args = {} @@ -1993,6 +2009,7 @@ def forward_absorb_core( extra_args = { "cos_sin_cache": self.rotary_emb.cos_sin_cache, "is_neox": self.rotary_emb.is_neox_style, + "llama_4_scaling": llama_4_scaling, } attn_output = self.attn_mqa( @@ -2023,6 +2040,10 @@ def forward_absorb_core( q = torch.cat([q_nope_out, q_pe], dim=-1) k = torch.cat([k_nope, k_pe], dim=-1) + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + attn_output = self.attn_mqa( q, k, @@ -2766,6 +2787,7 @@ def forward( residual: Optional[torch.Tensor], zero_allocator: BumpAllocator, gemm_output_zero_allocator: BumpAllocator = None, + llama_4_scaling: Optional[torch.Tensor] = None, ) -> torch.Tensor: quant_format = ( "mxfp4" @@ -2806,6 +2828,7 @@ def forward( hidden_states=hidden_states, forward_batch=forward_batch, zero_allocator=zero_allocator, + llama_4_scaling=llama_4_scaling, ) hidden_states, residual = self.layer_communicator.prepare_mlp( @@ -3024,6 +3047,9 @@ def __init__( ) self.layers_to_capture = [] + # llama_4_scaling: for supporting Mistral-Large-3 model + self.llama_4_scaling_config = getattr(config, "llama_4_scaling", None) + def 
get_input_embeddings(self) -> torch.Tensor: return self.embed_tokens @@ -3072,6 +3098,18 @@ def forward( if enable_prefill_cp(forward_batch, self.nsa_enable_prefill_cp): hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + # llama_4_scaling: for supporting Mistral-Large-3 model + # Compute llama 4 scaling once per forward pass if enabled + llama_4_scaling: Optional[torch.Tensor] = None + if self.llama_4_scaling_config is not None: + llama_4_scaling = _get_llama_4_scaling( + original_max_position_embeddings=self.llama_4_scaling_config[ + "original_max_position_embeddings" + ], + scaling_beta=self.llama_4_scaling_config["beta"], + positions=positions, + ) + normal_start_layer = self.start_layer normal_end_layer = self.end_layer if forward_batch.can_run_tbo: @@ -3095,6 +3133,7 @@ def forward( residual, zero_allocator, gemm_output_zero_allocator, + llama_4_scaling, ) if normal_end_layer != self.end_layer: @@ -3351,8 +3390,10 @@ def post_load_weights(self, is_nextn=False, weight_names=None): ): # For mixed quantization (experts int4, linear fp8), use linear_fp8_config selected_quant_config = getattr( - self.quant_config, "linear_fp8_config", self.quant_config + self.quant_config, "linear_fp8_config", None ) + if selected_quant_config is None: + selected_quant_config = self.quant_config weight_block_size = getattr( selected_quant_config, "weight_block_size", None ) @@ -3739,16 +3780,24 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal ): q_a_proj_weight = cached_a_proj[q_a_proj_name] kv_a_proj_weight = cached_a_proj[kv_a_proj_name] - cat_dim = 0 - if self.quant_config is not None and ( - self.quant_config.get_name() == "awq" - or self.quant_config.get_name() == "awq_marlin" - or self.quant_config.get_name() == "moe_wna16" - ): - cat_dim = 1 - fused_weight = torch.cat( - [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim - ) + + if q_a_proj_weight.shape == torch.Size( + [] + ) and kv_a_proj_weight.shape == torch.Size([]): + fused_weight = q_a_proj_weight + else: + cat_dim = 0 + if self.quant_config is not None and ( + self.quant_config.get_name() == "awq" + or self.quant_config.get_name() == "awq_marlin" + or self.quant_config.get_name() == "moe_wna16" + ): + cat_dim = 1 + + fused_weight = torch.cat( + [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim + ) + param_name = ( name.replace( "q_a_proj", "fused_qkv_a_proj_with_mqa" diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py new file mode 100644 index 000000000000..719d47915407 --- /dev/null +++ b/python/sglang/srt/models/mistral_large_3.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable + +import regex as re +import torch + +from sglang.srt.models.deepseek_v2 import DeepseekV3ForCausalLM + + +class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM): + # fmt: off + remapping = { + r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq\.(\w+)": r"model.layers.\1.self_attn.q_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": 
r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501 + r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501 + # FP8 scales + r"layers\.(\d+)\.attention\.k_fake_quantizer\.qscale_act": r"model.layers.\1.self_attn.mla_attn.mla_attn.k_scale", # noqa: E501 + r"layers\.(\d+)\.attention\.q_fake_quantizer\.qscale_act": r"model.layers.\1.self_attn.mla_attn.mla_attn.q_scale", # noqa: E501 + r"layers\.(\d+)\.attention\.v_fake_quantizer\.qscale_act": r"model.layers.\1.self_attn.mla_attn.mla_attn.v_scale", # noqa: E501 + r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501 + r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501 + r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501 + r"layers\.(\d+)\.router_biases": r"model.layers.\1.mlp.gate.e_score_correction_bias", # noqa: E501 + r"norm\.weight": "model.norm.weight", # noqa: E501 + r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501 + r"output\.weight": "lm_head.weight", # noqa: E501 + } + # fmt: on + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + return super().load_weights(self._iterable_remap_mistral_to_ds(weights)) + + def _iterable_remap_mistral_to_ds( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + """Remap Mistral parameters to DeepseekV2 parameters.""" + for name, loaded_weight in weights: + for k, v in self.remapping.items(): + match = re.fullmatch(k, name) + if match: + name = re.sub(k, v, name) + break + else: + import logging + + logging.warning(f"Unrecognized weight: {name}. Skipping.") + continue + + # Note(Andy): Unlike Llama, this implementation uses + # is_neox_style=False for RoPE, which matches Mistral's implementation. + # Thus we don't need to permute the q/k weights (unlike Llama) + + # Remapping scale names. We could do this in the regex above but it + # would triple the number of lines for most layers. + if name.endswith(".qscale_act"): + name = re.sub(r"\.qscale_act$", ".input_scale", name) + elif name.endswith(".qscale_weight"): + name = re.sub(r"\.qscale_weight$", ".weight_scale", name) + + if name.endswith(".weight_scale") and ".experts." 
not in name: + name = re.sub(r"\.weight_scale$", ".weight_scale_inv", name) + + yield name, loaded_weight + + +EntryClass = MistralLarge3ForCausalLM diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 249a5ce81bba..1f9c236bb2ef 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -16,10 +16,12 @@ Using mistral-community/pixtral-12b as reference. """ +from dataclasses import dataclass, fields from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn +import torch.nn.functional as F from transformers import PixtralVisionConfig, PretrainedConfig from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding from transformers.models.pixtral.modeling_pixtral import ( @@ -32,9 +34,398 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens -from sglang.srt.managers.schedule_batch import MultimodalInputs +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.mistral_large_3 import MistralLarge3ForCausalLM + +USE_XFORMERS_OPS = False +PATCH_MERGE = "patch_merge" + + +# Vision encoder +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float # for rope-2D + image_token_id: int + adapter_bias: bool = True + spatial_merge_size: int = 1 + add_pre_mm_projector_layer_norm: bool = False + mm_projector_id: str = "" + + +class PixtralForConditionalGeneration(nn.Module): + merge_by_field_config = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, config, prefix: str = "", **kwargs): + super().__init__() + self.config = config + dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + vision_args = { + key: value + for key, value in self.config.vision_config.to_dict().items() + if key in dataclass_fields + } + + self.vision_args = VisionEncoderArgs(**vision_args) + + self.language_model = MistralLarge3ForCausalLM( + config=self.config.text_config, + quant_config=kwargs.get("quant_config"), + ) + + self.vision_encoder = VisionTransformer(self.vision_args) + + if self.vision_args.add_pre_mm_projector_layer_norm: + self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5) + + if self.vision_args.mm_projector_id == PATCH_MERGE: + self.patch_merger = PatchMerger( + vision_encoder_dim=self.vision_args.hidden_size, + spatial_merge_size=self.vision_args.spatial_merge_size, + use_mlp_bias=False, + ) + + self.vision_language_adapter = VisionLanguageAdapter( + self.vision_args, dim=self.config.text_config.hidden_size + ) + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def load_weights(self, 
weights: Iterable[tuple[str, torch.Tensor]]): + def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("vision_encoder") + + def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("vision_language_adapter") + + def is_patch_merger(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("patch_merger") + + def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): + return weight[0].startswith("pre_mm_projector_norm") + + # Get references to parameters for direct loading + vision_encoder_dict = dict(self.vision_encoder.named_parameters()) + patch_merger_dict = ( + dict(self.patch_merger.named_parameters()) + if self.vision_args.mm_projector_id == PATCH_MERGE + else dict() + ) + pre_mm_projector_norm_dict = ( + dict(self.pre_mm_projector_norm.named_parameters()) + if self.vision_args.add_pre_mm_projector_layer_norm + else dict() + ) + vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters()) + + def llm_weights_generator(): + # Single pass over weights + for name, w in weights: + if is_vision_encoder_weights((name, w)): + # Load vision encoder weights directly + trimmed_name = ".".join(name.split(".")[1:]) + # NOTE: The current nvfp4 model has extra weights that we need to ignore, called + # vision_encoder.transformer.layers.*.attention.{k,v}_fake_quantizer.qscale_act + # TODO: Remove this if condition once the model is fixed + if "fake_quantizer.qscale_act" in trimmed_name: + continue + param = vision_encoder_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_patch_merger((name, w)): + # Load vision patch merger weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = patch_merger_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_pre_mm_projector_norm((name, w)): + # Load vision pre_mm_projector_norm weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = pre_mm_projector_norm_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + elif is_vision_lang_adapter_weights((name, w)): + # Load vision-language adapter weights directly + trimmed_name = ".".join(name.split(".")[1:]) + param = vision_lang_adapter_dict[trimmed_name] + with torch.no_grad(): + default_weight_loader(param, w) + else: + # LLM weights: yield them to be loaded + # by language_model.load_weights + yield (name, w) + + # Now we call the language model load with the generator + self.language_model.load_weights(llm_weights_generator()) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + images = [item.feature for item in items] + # Process images through vision encoder + image_features = self.vision_encoder(images) + if self.vision_args.add_pre_mm_projector_layer_norm: + image_features = image_features.view(-1, image_features.shape[-1]) + image_features = self.pre_mm_projector_norm(image_features) + if self.vision_args.mm_projector_id == PATCH_MERGE: + patch_size = self.vision_args.patch_size + img_patch_dims = [ + (img.shape[-2] // patch_size, img.shape[-1] // patch_size) + for img in images + for _ in range(img.shape[0]) + ] + image_features = self.patch_merger( + image_features, image_sizes=img_patch_dims + ) + image_embeds = self.vision_language_adapter(image_features) + return image_embeds + + def forward(self, input_ids, positions, forward_batch): + return 
general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.language_model, + multimodal_model=self, + positions=positions, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def get_embed_and_head(self): + return self.language_model.get_embed_and_head() + + +class PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__( + self, + vision_encoder_dim: int, + spatial_merge_size: int, + use_mlp_bias: bool = False, + ) -> None: + super().__init__() + + mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) + + self.spatial_merge_size = spatial_merge_size + self.mlp_input_dim = mlp_input_dim + + self.merging_layer = nn.Linear( + mlp_input_dim, + vision_encoder_dim, + bias=use_mlp_bias, + ) + + def forward( + self, x: torch.Tensor, image_sizes: list[tuple[int, int]] + ) -> torch.Tensor: + # image_sizes specified in tokens + assert sum([h * w for h, w in image_sizes]) == x.shape[-2] + + # x is (N, vision_encoder_dim) + x = self.permute(x, image_sizes) + + # x is (N / spatial_merge_size ** 2, + # vision_encoder_dim * spatial_merge_size ** 2) + x = self.merging_layer(x) + + # x is (N / spatial_merge_size ** 2, vision_encoder_dim) + return x + + def permute( + self, + x: torch.Tensor, + image_sizes: list[tuple[int, int]], + ) -> torch.Tensor: + """ + Args: + x: (N, D) where N is flattened and concatenated patch tokens + for all images + image_sizes: list of tuple of (height, width) in tokens for + each image + Returns: + image_features: reorders patch tokens so each grid of + (spatial_merge_size, spatial_merge_size) is contiguous. + now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) + """ + + sub_grids = get_sub_grids( + x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size + ) # list of [d x sub_grid_size x sub_grid_size x n_patches] + permuted_tensor: list[torch.Tensor] = [] + for grid in sub_grids: + n_patches = grid.shape[-1] + permuted_tensor.append( + grid.view(-1, n_patches).t() + ) # n_patches x d * sub_grid_size * sub_grid_size + return torch.cat( + permuted_tensor, dim=0 + ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) + + +def get_sub_grids( + x: torch.Tensor, + image_sizes: list[tuple[int, int]], + spatial_merge_size: int, +) -> list[torch.Tensor]: + # image_sizes specified in tokens + tokens_per_image = [h * w for h, w in image_sizes] + d = x.shape[-1] + all_img_sub_grids: list[torch.Tensor] = [] + sub_grid_size = spatial_merge_size + + for image_index, image_tokens in enumerate(x.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[ + None, :, :, : + ] # 1 x d x h x w + sub_grids = torch.nn.functional.unfold( + image_grid, kernel_size=sub_grid_size, stride=sub_grid_size + ) + sub_grids = sub_grids.view( + 1, d, sub_grid_size, sub_grid_size, -1 + ) # 1 x d x sub_grid_size x sub_grid_size x n_patches + + all_img_sub_grids.append(sub_grids[0]) + + return all_img_sub_grids + + +class VisionTransformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = 
Transformer(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: torch.Tensor | None = None + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.types.Device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: list[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + images: list of N_img images of variable sizes, + each of shape (B, C, H, W) + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for img in images] + + patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] + + patch_embeds = torch.cat(patch_embeds, dim=1) + patch_embeds_shape = patch_embeds.shape + patch_embeds = patch_embeds.view(-1, patch_embeds_shape[-1]) + patch_embeds = self.ln_pre(patch_embeds) + patch_embeds = patch_embeds.view(patch_embeds_shape) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + if USE_XFORMERS_OPS: + from xformers import ops as xops + + mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) + else: + from transformers.models.pixtral.modeling_pixtral import ( + generate_block_attention_mask, + ) + + mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + return self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + +def position_meshgrid( + patch_embeds_list: list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) + return positions class PixtralHFMLP(nn.Module): @@ -81,6 +472,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out +class VisionLanguageAdapter(nn.Module): + def __init__(self, args: VisionEncoderArgs, dim: int): + super().__init__() + assert isinstance(args, VisionEncoderArgs) + self.w_in = nn.Linear( + args.hidden_size, + dim, + bias=args.adapter_bias, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w_out(self.gelu(self.w_in(x))) + + class PixtralHFTransformerBlock(nn.Module): """Transformer block for PixtralHFVisionModel using SGLang components.""" @@ -156,6 +563,161 @@ def forward( return output +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + 
assert ndim > 1 + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) + to be indexed by (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def apply_rotary_emb_vit( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + assert freqs_cis.dtype == torch.complex64 + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class FeedForward(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + assert args.intermediate_size is not None + self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) + self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Attention(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + assert not args.hidden_size % args.num_attention_heads + self.n_heads = args.num_attention_heads + self.head_dim = args.hidden_size // args.num_attention_heads + + self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + batch, patches, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.reshape(batch, patches, self.n_heads, self.head_dim) + k = k.reshape(batch, patches, self.n_heads, self.head_dim) + v = v.reshape(batch, patches, self.n_heads, self.head_dim) + + q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) + + if USE_XFORMERS_OPS: + from xformers import ops as xops + + out = xops.memory_efficient_attention(q, k, v, attn_bias=mask) + else: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + out = out.transpose(1, 2) + + out = out.reshape(batch, patches, self.n_heads * self.head_dim) + return self.wo(out) + + +class TransformerBlock(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.attention = Attention(args) + self.feed_forward = 
FeedForward(args) + self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) + self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + attention_norm_x = self.attention_norm(x.view(-1, x.shape[-1])) + attention_norm_x = attention_norm_x.view(x.shape) + r = self.attention.forward(attention_norm_x, mask=mask, freqs_cis=freqs_cis) + h = x + r + ffn_norm_h = self.ffn_norm(h.view(-1, h.shape[-1])) + ffn_norm_h = ffn_norm_h.view(h.shape) + r = self.feed_forward.forward(ffn_norm_h) + out = h + r + return out + + +class Transformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor | None, + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + class PixtralHFTransformer(nn.Module): """Transformer for PixtralHFVisionModel using SGLang components.""" @@ -456,4 +1018,4 @@ class PixtralVisionModel(PixtralHFVisionModel): # Register the model classes for external access -EntryClass = [PixtralVisionModel] +EntryClass = [PixtralForConditionalGeneration, PixtralVisionModel] diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index af5cedec9fa6..6b6ab34ad1a0 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -6,7 +6,10 @@ _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) -from sglang.srt.models.pixtral import PixtralVisionModel +from sglang.srt.models.pixtral import ( + PixtralForConditionalGeneration, + PixtralVisionModel, +) from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, MultimodalSpecialTokens, @@ -14,7 +17,7 @@ class PixtralProcessor(BaseMultimodalProcessor): - models = [PixtralVisionModel] + models = [PixtralVisionModel, PixtralForConditionalGeneration] PAD_TOKEN = "" IMG_BREAK_TOKEN_ID = 12 @@ -30,7 +33,6 @@ def get_patch_grid_size( patch_width = patch_height = self.patch_size ratio = max(image_width / max_width, image_height / max_height) - if ratio > 1: image_width = int(math.floor(image_width / ratio)) image_height = int(math.floor(image_height / ratio)) @@ -52,6 +54,10 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.vision_config = hf_config.vision_config self.image_size = self.vision_config.image_size self.patch_size = self.vision_config.patch_size + + self._processor.patch_size = self.patch_size + self._processor.spatial_merge_size = self.vision_config.spatial_merge_size + self.mm_tokens = MultimodalSpecialTokens( image_token=_processor.image_token, image_token_id=self.IM_TOKEN_ID, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 691cc862949d..a07b71c5a02e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -944,7 +944,18 @@ def _handle_model_specific_adjustments(self): hf_config = self.get_hf_config() model_arch = hf_config.architectures[0] - if model_arch in ["DeepseekV3ForCausalLM"]: + + if model_arch in [ + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", + ]: + self.dtype = "bfloat16" + + if model_arch in [ + "DeepseekV3ForCausalLM", + "MistralLarge3ForCausalLM", + 
"PixtralForConditionalGeneration", + ]: if is_deepseek_nsa(hf_config): if ( self.attention_backend is None @@ -1050,7 +1061,7 @@ def _handle_model_specific_adjustments(self): # Default DeepSeek V3/R1 native FP8 when not explicitly set, # Because we need this condition for an assertion in # flashinfer_trtllm MoE runner backend. - if quant_method is None: + if quant_method is None and model_arch == "DeepseekV3ForCausalLM": self.quantization = "fp8" logger.info( "Quantization not specified, default to fp8 for DeepSeek on sm100" @@ -1693,6 +1704,8 @@ def _handle_speculative_decoding(self): "Glm4MoeForCausalLM", "BailingMoeForCausalLM", "BailingMoeV2ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ]: if self.speculative_draft_model_path is None: self.speculative_draft_model_path = self.model_path @@ -1933,6 +1946,8 @@ def _handle_deterministic_inference(self): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ] except Exception: pass @@ -4526,6 +4541,8 @@ def auto_choose_speculative_params(self: ServerArgs): "GptOssForCausalLM", "BailingMoeForCausalLM", "BailingMoeV2ForCausalLM", + "MistralLarge3ForCausalLM", + "PixtralForConditionalGeneration", ]: # The default value for deepseek and gpt-oss return (3, 1, 4) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 0e71dfb31383..4cacc88f49d7 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -60,7 +60,7 @@ from sglang.srt.configs.internvl import InternVLChatConfig from sglang.srt.connector import create_remote_connector from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR -from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset +from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ ChatGLMConfig, @@ -184,6 +184,37 @@ def _load_deepseek_v32_model( ) +# Temporary hack for Mistral Large +def _load_mistral_large_3_for_causal_LM( + model_path: str, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, +): + # first get the local path + local_path = download_from_hf(model_path) + # then load the config file in json + parser = mistral_utils.MistralConfigParser() + config_dict, _ = parser.parse(local_path) + + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as f: + json.dump(config_dict, f) + f.flush() + loaded_config = AutoConfig.from_pretrained( + f.name, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + text_config = getattr(loaded_config, "text_config", None) + if text_config is not None and isinstance(text_config, dict): + text_config = AutoConfig.for_model(**text_config) + setattr(loaded_config, "text_config", text_config) + vision_config = getattr(loaded_config, "vision_config", None) + if vision_config is not None and isinstance(vision_config, dict): + vision_config = AutoConfig.for_model(**vision_config) + setattr(loaded_config, "vision_config", vision_config) + + return loaded_config + + def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool: # TODO: Remove this workaround related when AutoConfig correctly identifies deepseek-ocr. # Hugging Face's AutoConfig currently misidentifies it as deepseekvl2. 
@@ -215,17 +246,21 @@ def get_config(
         client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
         model = client.get_local_dir()
 
-    try:
-        config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
-        )
-
-    except ValueError as e:
-        if not "deepseek_v32" in str(e):
-            raise e
-        config = _load_deepseek_v32_model(
+    if "mistral-large-3" in str(model).lower():
+        config = _load_mistral_large_3_for_causal_LM(
             model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
         )
+    else:
+        try:
+            config = AutoConfig.from_pretrained(
+                model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+            )
+        except ValueError as e:
+            if not "deepseek_v32" in str(e):
+                raise e
+            config = _load_deepseek_v32_model(
+                model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
+            )
 
     if (
         config.architectures is not None
@@ -465,14 +500,20 @@ def get_processor(
 ):
     # pop 'revision' from kwargs if present.
     revision = kwargs.pop("revision", tokenizer_revision)
-
-    config = AutoConfig.from_pretrained(
-        tokenizer_name,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
-
+    if "mistral-large-3" in str(tokenizer_name).lower():
+        config = _load_mistral_large_3_for_causal_LM(
+            tokenizer_name,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+    else:
+        config = AutoConfig.from_pretrained(
+            tokenizer_name,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
     if _is_deepseek_ocr_model(config):
         # Temporary hack for load deepseek-ocr
         config.model_type = "deepseek-ocr"
diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py
new file mode 100644
index 000000000000..e23abb53c5fd
--- /dev/null
+++ b/python/sglang/srt/utils/mistral_utils.py
@@ -0,0 +1,295 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
+from pathlib import Path
+from typing import Any
+
+from transformers import PretrainedConfig, WhisperConfig
+
+from sglang.srt.utils import logger
+
+
+def adapt_config_dict(
+    config_dict: dict[str, Any], model: str, **kwargs
+) -> tuple[dict[str, Any], PretrainedConfig]:
+    config_dict.update(kwargs)
+    config_dict = _remap_general_mistral_args(config_dict)
+
+    if bool(config_dict.get("quantization")):
+        config_dict = _remap_mistral_quantization_args(config_dict)
+
+    is_moe = bool(config_dict.get("moe"))
+    is_mistral_large_3 = (
+        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
+    )
+    if is_moe:
+        if is_mistral_large_3:
+            config_dict = _remap_moe_args(config_dict)
+            config_dict["model_type"] = "deepseek_v3"
+            config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
+
+            assert (
+                "llama_4_scaling" in config_dict
+            ), "MistralLarge3 expects a llama_4_scaling config."
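+            # Validate the llama_4_scaling block: both the original context length
+            # ("original_max_position_embeddings") and the scaling factor ("beta")
+            # must be provided.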
+ llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [ + key in config_dict["llama_4_scaling"] + for key in llama_4_scaling_config_keys + ] + ), ( + "llama_4_scaling config should define the keys: " + f"{','.join(llama_4_scaling_config_keys)}" + ) + else: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if bool(config_dict.get("yarn")): + config_dict = _remap_mistral_yarn_args(config_dict) + + if bool(config_dict.get("llama_4_scaling")): + llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [ + key in config_dict["llama_4_scaling"] + for key in llama_4_scaling_config_keys + ] + ), ( + "llama_4_scaling config should define the keys: " + f"{','.join(llama_4_scaling_config_keys)}" + ) + + is_vision = bool( + (config_dict.get("multimodal") or {}).get("vision_encoder_args") + or config_dict.get("vision_encoder") + ) + is_audio = bool( + ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get( + "encoder_args" + ) + ) + + assert not (is_vision and is_audio), "Vision and audio are mutually exclusive" + + if is_vision: + config_dict = _remap_mistral_vision_args(config_dict) + if is_audio: + config_dict = _remap_mistral_audio_args(config_dict) + + config = PretrainedConfig.from_dict(config_dict) + + logger.debug("Initialized config %s", config) + + return config_dict, config + + +def _remap_mistral_vision_args(config: dict) -> dict: + if config.get("multimodal"): + vision_config = config.pop("multimodal") + else: + vision_config = config.pop("vision_encoder") + + quant_config = config.get("quantization_config") + + config = { + "model_type": "pixtral", + "architectures": ["PixtralForConditionalGeneration"], + "text_config": config, + "vision_config": {"model_type": "pixtral", **vision_config}, + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_mistral_yarn_args(config: dict) -> dict: + yarn_config_map = { + "factor": "factor", + "original_max_position_embeddings": "original_max_position_embeddings", + "beta": "beta_fast", + "alpha": "beta_slow", + "apply_scale": None, + } + yarn_config = config.get("yarn") or {} + config["rope_scaling"] = { + "rope_type": "yarn", + "mscale_all_dim": 1, + } + for old_name, new_name in yarn_config_map.items(): + if old_name in yarn_config: + value = yarn_config.pop(old_name) + if new_name is not None: + config["rope_scaling"][new_name] = value + + assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" + + return config + + +def _remap_general_mistral_args(config: dict) -> dict: + # Mistral key -> HF key + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + # HF key -> (Mistral key, default value) + top_level_mapping_with_default = { + "model_type": ("model_type", "transformer"), + "hidden_act": ("activation", "silu"), + "tie_word_embeddings": ("tied_embeddings", False), + "max_seq_len": ("max_seq_len", 128_000), + "max_position_embeddings": ("max_position_embeddings", 128_000), + } + + for key, new_key in config_mapping.items(): + if key in config: + config[new_key] = config.pop(key) + + for new_key, (key, default_value) in top_level_mapping_with_default.items(): + config[new_key] = config.pop(key, default_value) + + return config + + +def 
_remap_mistral_quantization_args(config: dict) -> dict: + if config.get("quantization"): + quantization = config.pop("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + qscheme_act = quantization.get("qscheme_act") + assert qscheme_act in ( + "NO_SCALES", + "TENSOR", + None, + ), "Only NO_SCALES and TENSOR (default) are supported for qscheme_act" + is_dynamic = qscheme_act == "NO_SCALES" + config["quantization_config"] = { + "quant_method": "fp8", + "activation_scheme": "dynamic" if is_dynamic else "static", + } + else: + raise ValueError(f"Found unknown quantization='{quantization}' in config") + + return config + + +def _remap_mistral_audio_args(config: dict) -> dict: + whisper_args = config["multimodal"].pop("whisper_model_args") + encoder_args = whisper_args["encoder_args"] + downsample_args = whisper_args["downsample_args"] + + quant_config = config.get("quantization_config") + config = { + "model_type": "whixtral", + "architectures": ["VoxtralForConditionalGeneration"], + "text_config": PretrainedConfig.from_dict(config), + "audio_config": WhisperConfig( + num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], + window_size=encoder_args["audio_encoding_args"]["window_size"], + sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], + hop_length=encoder_args["audio_encoding_args"]["hop_length"], + downsample_factor=downsample_args["downsample_factor"], + d_model=encoder_args["dim"], + encoder_layers=encoder_args["n_layers"], + encoder_ffn_dim=encoder_args["hidden_dim"], + encoder_attention_heads=encoder_args["n_heads"], + vocab_size=encoder_args["vocab_size"], + max_source_positions=encoder_args["max_source_positions"], + is_encoder_decoder=False, # Override WhisperConfig default + ), + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_moe_args(config: dict) -> dict: + moe_config_map = { + "route_every_n": "moe_layer_freq", + "first_k_dense_replace": "first_k_dense_replace", + "num_experts_per_tok": "num_experts_per_tok", + "num_experts": "n_routed_experts", + "expert_hidden_dim": "moe_intermediate_size", + "routed_scale": "routed_scaling_factor", + "num_shared_experts": "n_shared_experts", + "num_expert_groups": "n_group", + "num_expert_groups_per_tok": "topk_group", + } + moe_config = config.get("moe", {}) + for old_name, new_name in moe_config_map.items(): + if old_name in moe_config: + value = moe_config.pop(old_name) + config[new_name] = value + + config["topk_method"] = None + config["routing_method_type"] = 1 # RoutingMethodType.Renormalize + config["scoring_func"] = "softmax" + + return config + + +class MistralConfigParser: + def get_hf_file_to_dict( + self, file_name: str, model: str | Path, revision: str | None = "main" + ): + file_path = Path(model) / file_name + if not file_path.is_file(): + # TODO: Add logic to download from HF in case file is not locally found + raise FileNotFoundError(f"File not found {model}, {file_name}") + + if file_path is not None and file_path.is_file(): + with open(file_path) as file: + return json.load(file) + + return None + + def _download_mistral_config_file(self, model, revision) -> dict: + config_file_name = "params.json" + config_dict = self.get_hf_file_to_dict(config_file_name, model, revision) + if config_dict is None: + raise ValueError( + f"Failed to load mistral '{config_file_name}' config for model " + f"{model}. Please check if the model is a mistral-format model " + f"and if the config file exists." 
+            )
+        assert isinstance(config_dict, dict)
+        return config_dict
+
+    def parse(
+        self,
+        model: str | Path,
+        revision: str | None = None,
+        **kwargs,
+    ) -> tuple[dict, PretrainedConfig]:
+        # Parse a Mistral-format params.json into a config dict and an
+        # HF-compatible PretrainedConfig.
+        config_dict = self._download_mistral_config_file(model, revision)
+        if config_dict.get("max_position_embeddings") is None:
+            logger.warning(
+                "The params.json file is missing 'max_position_embeddings';"
+                " defaulting to 128000."
+            )
+            config_dict["max_position_embeddings"] = 128_000
+
+        config_dict, config = adapt_config_dict(config_dict, model)
+
+        # Mistral configs may define sliding_window as a list[int]. Convert it
+        # to an int and add a layer_types list[str] to make the config HF-compatible.
+        if (sliding_window := getattr(config, "sliding_window", None)) and isinstance(
+            sliding_window, list
+        ):
+            pattern_repeats = config.num_hidden_layers // len(sliding_window)
+            layer_types = sliding_window * pattern_repeats
+            config.layer_types = [
+                "full_attention" if layer_type is None else "sliding_attention"
+                for layer_type in layer_types
+            ]
+            config.sliding_window = next(filter(None, sliding_window), None)
+
+        return config_dict, config
diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs
index 68fcfb3bd5a8..65dff42be315 100644
--- a/sgl-router/src/routers/router_manager.rs
+++ b/sgl-router/src/routers/router_manager.rs
@@ -363,13 +363,25 @@ impl RouterTrait for RouterManager {
         }
     }
 
-    async fn get_model_info(&self, _req: Request) -> Response {
-        // TODO: Extract model from request and route to appropriate router
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "Model info endpoint not yet implemented in RouterManager",
-        )
-            .into_response()
+    async fn get_model_info(&self, req: Request) -> Response {
+        // Route to default router or first available router
+        let router_id = {
+            let default_router = self.default_router.read().unwrap();
+            default_router.clone()
+        };
+
+        let router = if let Some(id) = router_id {
+            self.routers.get(&id).map(|r| r.clone())
+        } else {
+            // If no default, use first available router
+            self.routers.iter().next().map(|r| r.value().clone())
+        };
+
+        if let Some(router) = router {
+            router.get_model_info(req).await
+        } else {
+            (StatusCode::SERVICE_UNAVAILABLE, "No routers available").into_response()
+        }
     }
 
     async fn route_generate(