From 7eeffcb93a21a5ef2723202136b31cd683e64b93 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 28 Feb 2026 13:55:27 +0000 Subject: [PATCH 01/27] Add Mistral4/Pixtral support changes --- python/sglang/srt/models/pixtral.py | 7 ++++++- .../srt/multimodal/processors/pixtral.py | 20 ++++++++++++++++--- .../sglang/srt/utils/hf_transformers_utils.py | 12 +++++------ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 265801901421..1981fe9d55a0 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -82,9 +82,14 @@ def __init__(self, *, config, prefix: str = "", **kwargs): super().__init__() self.config = config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + config_dict = self.config.vision_config.to_dict() + if config_dict.get("rope_parameters"): # transformers v5 compatibility + config_dict["rope_theta"] = config_dict["rope_parameters"].get("rope_theta") + config_dict["rope_scaling"] = config_dict["rope_parameters"] + config_dict.pop("rope_parameters") vision_args = { key: value - for key, value in self.config.vision_config.to_dict().items() + for key, value in config_dict.items() if key in dataclass_fields } diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index b923ff342a19..152f1a164279 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -10,6 +10,8 @@ PixtralForConditionalGeneration, PixtralVisionModel, ) +from transformers import PreTrainedTokenizerBase + from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, MultimodalSpecialTokens, @@ -20,6 +22,7 @@ class PixtralProcessor(BaseMultimodalProcessor): models = [PixtralVisionModel, PixtralForConditionalGeneration] PAD_TOKEN = "" + DEFAULT_IMAGE_TOKEN = "[IMG]" IMG_BREAK_TOKEN_ID = 12 
IMG_END_TOKEN_ID = 13 @@ -59,11 +62,22 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): if hasattr(self.vision_config, "spatial_merge_size"): self._processor.spatial_merge_size = self.vision_config.spatial_merge_size + # When the model lacks processor_config.json (e.g. Mistral-Small-4), + # AutoProcessor returns a bare tokenizer without an image_token attribute. + tokenizer = ( + _processor + if isinstance(_processor, PreTrainedTokenizerBase) + else _processor.tokenizer + ) + self.image_token = getattr( + _processor, "image_token", self.DEFAULT_IMAGE_TOKEN + ) + self.mm_tokens = MultimodalSpecialTokens( - image_token=_processor.image_token, + image_token=self.image_token, image_token_id=self.IM_TOKEN_ID, ).build(_processor) - _processor.tokenizer.add_special_tokens( + tokenizer.add_special_tokens( { "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN), } @@ -103,5 +117,5 @@ async def process_mm_data_async( "mm_items": mm_items, "input_ids": input_ids.tolist(), "im_token_id": self.IM_TOKEN_ID, - "im_token": self._processor.image_token, + "im_token": self.image_token, } diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 9c6886fe54ca..b9cdbecb461d 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -14,6 +14,7 @@ """Utilities for Huggingface Transformers.""" import contextlib +from functools import lru_cache import json import logging import os @@ -207,11 +208,11 @@ def _load_deepseek_v32_model( # Temporary hack for Mistral Large +@lru_cache(maxsize=2) def _load_mistral_large_3_for_causal_LM( model_path: str, trust_remote_code: bool = False, revision: Optional[str] = None, - **kwargs, ): # first get the local path local_path = download_from_hf(model_path) @@ -223,7 +224,7 @@ def _load_mistral_large_3_for_causal_LM( json.dump(config_dict, f) f.flush() loaded_config = AutoConfig.from_pretrained( - 
f.name, trust_remote_code=trust_remote_code, revision=revision, **kwargs + f.name, trust_remote_code=trust_remote_code, revision=revision ) text_config = getattr(loaded_config, "text_config", None) if text_config is not None and isinstance(text_config, dict): @@ -305,9 +306,9 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - if "mistral-large-3" in str(model).lower(): + if "mistral-large-3" in str(model).lower() or "mistral-small-4" in str(model).lower(): config = _load_mistral_large_3_for_causal_LM( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + model, trust_remote_code=trust_remote_code, revision=revision ) else: _ensure_llama_flash_attention2_compat() @@ -583,12 +584,11 @@ def get_processor( ): # pop 'revision' from kwargs if present. revision = kwargs.pop("revision", tokenizer_revision) - if "mistral-large-3" in str(tokenizer_name).lower(): + if "mistral-large-3" in str(tokenizer_name).lower() or "mistral-small-4" in str(tokenizer_name).lower(): config = _load_mistral_large_3_for_causal_LM( tokenizer_name, trust_remote_code=trust_remote_code, revision=revision, - **kwargs, ) else: _ensure_llama_flash_attention2_compat() From c3297fc8e13360b531ab4879c24dde9d1229b33b Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 5 Mar 2026 22:01:50 +0000 Subject: [PATCH 02/27] lint Signed-off-by: Xinyuan Tong --- python/sglang/srt/models/pixtral.py | 6 ++---- python/sglang/srt/multimodal/processors/pixtral.py | 7 ++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 1981fe9d55a0..166c3753453a 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -83,14 +83,12 @@ def __init__(self, *, config, prefix: str = "", **kwargs): self.config = config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} config_dict = 
self.config.vision_config.to_dict() - if config_dict.get("rope_parameters"): # transformers v5 compatibility + if config_dict.get("rope_parameters"): # transformers v5 compatibility config_dict["rope_theta"] = config_dict["rope_parameters"].get("rope_theta") config_dict["rope_scaling"] = config_dict["rope_parameters"] config_dict.pop("rope_parameters") vision_args = { - key: value - for key, value in config_dict.items() - if key in dataclass_fields + key: value for key, value in config_dict.items() if key in dataclass_fields } self.vision_args = VisionEncoderArgs(**vision_args) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 152f1a164279..dac7a0b8aaab 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -2,6 +2,7 @@ import math from typing import List, Union +from transformers import PreTrainedTokenizerBase from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) @@ -10,8 +11,6 @@ PixtralForConditionalGeneration, PixtralVisionModel, ) -from transformers import PreTrainedTokenizerBase - from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, MultimodalSpecialTokens, @@ -69,9 +68,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): if isinstance(_processor, PreTrainedTokenizerBase) else _processor.tokenizer ) - self.image_token = getattr( - _processor, "image_token", self.DEFAULT_IMAGE_TOKEN - ) + self.image_token = getattr(_processor, "image_token", self.DEFAULT_IMAGE_TOKEN) self.mm_tokens = MultimodalSpecialTokens( image_token=self.image_token, From 296fcd5af375466ae483cdee8f4ba50653356702 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 5 Mar 2026 22:02:37 +0000 Subject: [PATCH 03/27] Add special handling for mistral 4 Signed-off-by: Xinyuan Tong --- 
.../sglang/srt/utils/hf_transformers_utils.py | 48 +++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index b9cdbecb461d..38ecb88139de 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -14,12 +14,12 @@ """Utilities for Huggingface Transformers.""" import contextlib -from functools import lru_cache import json import logging import os import tempfile import warnings +from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Type, Union @@ -306,7 +306,10 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - if "mistral-large-3" in str(model).lower() or "mistral-small-4" in str(model).lower(): + if ( + "mistral-large-3" in str(model).lower() + or "mistral-small-4" in str(model).lower() + ): config = _load_mistral_large_3_for_causal_LM( model, trust_remote_code=trust_remote_code, revision=revision ) @@ -584,7 +587,10 @@ def get_processor( ): # pop 'revision' from kwargs if present. revision = kwargs.pop("revision", tokenizer_revision) - if "mistral-large-3" in str(tokenizer_name).lower() or "mistral-small-4" in str(tokenizer_name).lower(): + if ( + "mistral-large-3" in str(tokenizer_name).lower() + or "mistral-small-4" in str(tokenizer_name).lower() + ): config = _load_mistral_large_3_for_causal_LM( tokenizer_name, trust_remote_code=trust_remote_code, @@ -658,6 +664,42 @@ def get_processor( ) else: raise e + # If processor is a bare tokenizer (e.g. Mistral-Small-4 has no processor_config.json) + # and the model is a vision model (pixtral), wrap it in a proper PixtralProcessor + # so that image data is actually processed through the image processor. 
+ if ( + isinstance(processor, PreTrainedTokenizerBase) + and getattr(config, "model_type", None) == "pixtral" + ): + from transformers.models.pixtral.image_processing_pixtral import ( + PixtralImageProcessor, + ) + from transformers.models.pixtral.processing_pixtral import ( + PixtralProcessor as HFPixtralProcessor, + ) + + vision_config = config.vision_config + if isinstance(vision_config, dict): + patch_size = vision_config.get("patch_size", 16) + image_size = vision_config.get("image_size", 1024) + spatial_merge_size = vision_config.get("spatial_merge_size", 1) + else: + patch_size = getattr(vision_config, "patch_size", 16) + image_size = getattr(vision_config, "image_size", 1024) + spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1) + + image_processor = PixtralImageProcessor( + do_resize=True, + size={"longest_edge": image_size}, + patch_size={"height": patch_size, "width": patch_size}, + ) + processor = HFPixtralProcessor( + image_processor=image_processor, + tokenizer=processor, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + ) + tokenizer = get_tokenizer_from_processor(processor) attach_additional_stop_token_ids(tokenizer) From c7457c9adc8c57f92fe6c5a0f1533d84d6caea11 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 5 Mar 2026 22:02:57 +0000 Subject: [PATCH 04/27] add reasoning parser for mistral Signed-off-by: Xinyuan Tong --- python/sglang/srt/parser/reasoning_parser.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index cd346625ae31..84fb5d6a8d2d 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -361,6 +361,33 @@ def __init__( ) +class MistralDetector(BaseReasoningFormatDetector): + """ + Detector for Mistral models with reasoning (e.g., Mistral-Small-4-119B-2602). 
+ Assumes reasoning format: + [THINK]reasoning content[/THINK]answer + + Reasoning is optional — it only appears when reasoning_effort="high" is set. + When reasoning_effort="none", the model outputs directly without thinking tokens. + """ + + def __init__( + self, + stream_reasoning: bool = True, + force_reasoning: bool = False, + continue_final_message: bool = False, + previous_content: str = "", + ): + super().__init__( + "[THINK]", + "[/THINK]", + force_reasoning=force_reasoning, + stream_reasoning=stream_reasoning, + continue_final_message=continue_final_message, + previous_content=previous_content, + ) + + class ReasoningParser: """ Parser that handles both streaming and non-streaming scenarios for extracting @@ -385,6 +412,7 @@ class ReasoningParser: "minimax-append-think": MiniMaxAppendThinkDetector, "step3": DeepSeekR1Detector, "step3p5": DeepSeekR1Detector, + "mistral": MistralDetector, "nano_v3": NanoV3Detector, "interns1": Qwen3Detector, } From 557d6fdbc63566f98bb03ee6a46f9a9c5280a82b Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 5 Mar 2026 22:12:44 +0000 Subject: [PATCH 05/27] Set default reasoning_effort to None in ChatCompletionRequest Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 1550c707a311..90bffd8431f5 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -589,7 +589,7 @@ class ChatCompletionRequest(BaseModel): return_routed_experts: bool = False return_cached_tokens_details: bool = False reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field( - default="medium", + default=None, description="Constrains effort on reasoning for reasoning models. " "'low' is the least effort, 'high' is the most effort. 
Reducing reasoning effort can " "result in faster responses and fewer tokens used on reasoning in a response. " From 2c4349ffb54cdc7071c6be917d78cebd1d91f9bf Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 5 Mar 2026 23:15:04 +0000 Subject: [PATCH 06/27] fix: Add activation type mapping for FlashInfer in moe_runner Signed-off-by: Xinyuan Tong --- .../moe/moe_runner/flashinfer_trtllm.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index b76e37b0b6bb..fa7b700cfb8d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -17,6 +17,31 @@ MoeRunnerConfig, register_fused_func, ) + + +def _get_flashinfer_activation_type(activation: str, is_gated: bool) -> int: + """Map SGLang activation config to flashinfer ActivationType int values. + + See flashinfer.fused_moe.core.ActivationType for the enum definition. 
+ """ + if is_gated: + if activation == "silu": + return 3 # Swiglu + elif activation == "gelu": + return 4 # Geglu + else: + raise ValueError(f"Unsupported gated activation: {activation}") + else: + if activation == "silu": + return 2 # Silu + elif activation == "gelu": + return 0 # Gelu + elif activation == "relu": + return 1 # Relu + else: + raise ValueError(f"Unsupported activation: {activation}") + + from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, scaled_fp8_quant, @@ -328,6 +353,9 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( use_routing_scales_on_input=quant_info.use_routing_scales_on_input, routing_method_type=routing_method_type, tune_max_num_tokens=next_power_of_2(a_q.shape[0]), + activation_type=_get_flashinfer_activation_type( + runner_config.activation, runner_config.is_gated + ), ) return StandardCombineInput(hidden_states=output) From 0322d01b12429c732e44792213c0e742915d3db7 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Fri, 6 Mar 2026 14:50:50 +0000 Subject: [PATCH 07/27] fix: add reasoning request handling for mistral 4 Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/serving_chat.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index fde200f82688..b472960924d6 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -1209,7 +1209,9 @@ def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int: return idx def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: - """Judge whether the request needs reasoning""" + """Judge whether the request needs reasoning for hybrid reasoning models + NOTE: This is predefined based on model's chat template + """ if not self.reasoning_parser: return False if self.reasoning_parser in ["deepseek-v3"]: @@ -1230,6 
+1232,9 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking") is not False ) + if self.reasoning_parser in ["mistral"]: + # Mistral models only reason when reasoning_effort="high" + return request.reasoning_effort == "high" return True # default async def _process_tool_call_stream( From 4802ecc21d8c9dc557da9ffa2e26b40f319bd85a Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Fri, 6 Mar 2026 20:56:44 +0000 Subject: [PATCH 08/27] fix: streamline vision config handling in get_processor function Signed-off-by: Xinyuan Tong --- python/sglang/srt/utils/hf_transformers_utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 38ecb88139de..1633a3595f4e 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -679,14 +679,9 @@ def get_processor( ) vision_config = config.vision_config - if isinstance(vision_config, dict): - patch_size = vision_config.get("patch_size", 16) - image_size = vision_config.get("image_size", 1024) - spatial_merge_size = vision_config.get("spatial_merge_size", 1) - else: - patch_size = getattr(vision_config, "patch_size", 16) - image_size = getattr(vision_config, "image_size", 1024) - spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1) + patch_size = vision_config.patch_size + image_size = vision_config.image_size + spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1) image_processor = PixtralImageProcessor( do_resize=True, From 04a8673ed9a10f43d7fcc3903b9d115bff6de698 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Fri, 6 Mar 2026 21:01:54 +0000 Subject: [PATCH 09/27] fix: adjust patch grid size calculation to incorporate spatial merge size Signed-off-by: Xinyuan Tong --- 
python/sglang/srt/multimodal/processors/pixtral.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index dac7a0b8aaab..15e28f070a5c 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -33,17 +33,29 @@ def get_patch_grid_size( ) -> tuple[int, int]: max_width = max_height = self.image_size patch_width = patch_height = self.patch_size + spatial_merge_size = getattr(self.vision_config, "spatial_merge_size", 1) ratio = max(image_width / max_width, image_height / max_height) if ratio > 1: image_width = int(math.floor(image_width / ratio)) image_height = int(math.floor(image_height / ratio)) + # Use effective_patch_size = patch_size * spatial_merge_size so that + # the resulting patch grid dimensions are divisible by spatial_merge_size. + # This matches the reference mistral_common implementation and ensures + # the PatchMerger (which groups spatial_merge_size^2 patches) works correctly. 
+ effective_patch_width = patch_width * spatial_merge_size + effective_patch_height = patch_height * spatial_merge_size + nrows, ncols = _get_pixtral_hf_num_image_tokens( (image_height, image_width), - (patch_height, patch_width), + (effective_patch_height, effective_patch_width), ) + # Scale back: each "effective token" is spatial_merge_size actual patch tokens + nrows *= spatial_merge_size + ncols *= spatial_merge_size + return ncols, nrows def __init__(self, hf_config, server_args, _processor, *args, **kwargs): From 3ae9d1e37a1dc01c833930ac3c834d240adb19fa Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 12 Mar 2026 18:52:16 +0000 Subject: [PATCH 10/27] fix(pixtral): use effective_patch_size for image resize and simplify processor - Use patch_size * spatial_merge_size as the effective patch size in PixtralImageProcessor so images resize to multiples of 28 (not 14), matching PatchMerger requirements with spatial_merge_size=2 - Remove manual _resize and get_patch_grid_size methods, relying on the correctly configured HF image processor instead - Add multi-image offset splitting for per-image MultimodalDataItem - Remove unused torch import --- .../srt/multimodal/processors/pixtral.py | 101 +++++++++--------- .../sglang/srt/utils/hf_transformers_utils.py | 3 +- 2 files changed, 50 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 15e28f070a5c..c8e6b5be0307 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -1,4 +1,3 @@ -import asyncio import math from typing import List, Union @@ -7,6 +6,7 @@ _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) +from sglang.srt.managers.schedule_batch import Modality from sglang.srt.models.pixtral import ( PixtralForConditionalGeneration, PixtralVisionModel, @@ -22,48 +22,12 @@ class PixtralProcessor(BaseMultimodalProcessor): PAD_TOKEN = "" 
DEFAULT_IMAGE_TOKEN = "[IMG]" - IMG_BREAK_TOKEN_ID = 12 - IMG_END_TOKEN_ID = 13 - - def get_patch_grid_size( - self, - *, - image_width: int, - image_height: int, - ) -> tuple[int, int]: - max_width = max_height = self.image_size - patch_width = patch_height = self.patch_size - spatial_merge_size = getattr(self.vision_config, "spatial_merge_size", 1) - - ratio = max(image_width / max_width, image_height / max_height) - if ratio > 1: - image_width = int(math.floor(image_width / ratio)) - image_height = int(math.floor(image_height / ratio)) - - # Use effective_patch_size = patch_size * spatial_merge_size so that - # the resulting patch grid dimensions are divisible by spatial_merge_size. - # This matches the reference mistral_common implementation and ensures - # the PatchMerger (which groups spatial_merge_size^2 patches) works correctly. - effective_patch_width = patch_width * spatial_merge_size - effective_patch_height = patch_height * spatial_merge_size - - nrows, ncols = _get_pixtral_hf_num_image_tokens( - (image_height, image_width), - (effective_patch_height, effective_patch_width), - ) - - # Scale back: each "effective token" is spatial_merge_size actual patch tokens - nrows *= spatial_merge_size - ncols *= spatial_merge_size - - return ncols, nrows def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_TOKEN_ID = getattr( hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID ) - # Instantiate the patcher logic helper using the class defined above self.vision_config = hf_config.vision_config self.image_size = self.vision_config.image_size @@ -73,8 +37,6 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): if hasattr(self.vision_config, "spatial_merge_size"): self._processor.spatial_merge_size = self.vision_config.spatial_merge_size - # When the model lacks processor_config.json (e.g. 
Mistral-Small-4), - # AutoProcessor returns a bare tokenizer without an image_token attribute. tokenizer = ( _processor if isinstance(_processor, PreTrainedTokenizerBase) @@ -92,14 +54,6 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): } ) - async def _resize(self, image): - num_w_tokens, num_h_tokens = self.get_patch_grid_size( - image_width=image.size[0], - image_height=image.size[1], - ) - new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size) - return image.resize(new_size) - async def process_mm_data_async( self, image_data: List[Union[str, bytes]], @@ -115,12 +69,53 @@ async def process_mm_data_async( return_text=True, ) if mm_data.images: - resize_tasks = [self._resize(image) for image in mm_data.images] - mm_data.images = await asyncio.gather(*resize_tasks) - - mm_items, input_ids, _ = self.process_and_combine_mm_data( - mm_data, self.mm_tokens - ) + # Track per-image row counts for multi-image offset splitting + spatial_merge_size = getattr( + self.vision_config, "spatial_merge_size", 1 + ) + effective_patch = self.patch_size * spatial_merge_size + image_nrows = [] + for img in mm_data.images: + w, h = img.size + ratio = max(w / self.image_size, h / self.image_size) + if ratio > 1: + w = int(math.floor(w / ratio)) + h = int(math.floor(h / ratio)) + nrows, ncols = _get_pixtral_hf_num_image_tokens( + (h, w), (effective_patch, effective_patch) + ) + image_nrows.append(nrows) + + mm_items, input_ids, _ = self.process_and_combine_mm_data( + mm_data, self.mm_tokens + ) + + # For multi-image: split single IMAGE mm_item into per-image items + if len(mm_data.images) > 1: + from sglang.srt.managers.schedule_batch import MultimodalDataItem + + old_item = next( + item for item in mm_items if item.modality == Modality.IMAGE + ) + all_offsets = old_item.offsets + old_feature = old_item.feature + + mm_items = [ + item for item in mm_items if item.modality != Modality.IMAGE + ] + offset_idx = 0 + for i, img in 
enumerate(mm_data.images): + nr = image_nrows[i] + item_offsets = all_offsets[offset_idx : offset_idx + nr] + offset_idx += nr + new_item = MultimodalDataItem(modality=Modality.IMAGE) + new_item.feature = old_feature[i : i + 1] + new_item.offsets = item_offsets + mm_items.append(new_item) + else: + mm_items, input_ids, _ = self.process_and_combine_mm_data( + mm_data, self.mm_tokens + ) return { "mm_items": mm_items, diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 1633a3595f4e..ddfbe75fe62c 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -683,10 +683,11 @@ def get_processor( image_size = vision_config.image_size spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1) + effective_patch = patch_size * spatial_merge_size image_processor = PixtralImageProcessor( do_resize=True, size={"longest_edge": image_size}, - patch_size={"height": patch_size, "width": patch_size}, + patch_size={"height": effective_patch, "width": effective_patch}, ) processor = HFPixtralProcessor( image_processor=image_processor, From 07744ec784aab81a7d33231660f9b25e085d370b Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 12 Mar 2026 18:52:22 +0000 Subject: [PATCH 11/27] feat(mmmu): add --model and --reasoning-effort flags to benchmark - Add --model flag (default "default") to avoid hardcoded model name - Add --reasoning-effort flag passed as top-level request field - Support local image paths via base64 data URI encoding - Pass reasoning_effort and model as explicit parameters instead of smuggling through sampling_params dict --- benchmark/mmmu/bench_sglang.py | 46 +++++++++++++++++++++++++++------- benchmark/mmmu/eval_utils.py | 8 ++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index d9426ae5a3ac..27b852c83bf0 100644 --- 
a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -11,11 +11,14 @@ import argparse import asyncio +import base64 +import mimetypes import re import sys import time import traceback from dataclasses import dataclass, field +from pathlib import Path from typing import Any, List, Optional, Tuple import aiohttp @@ -74,7 +77,12 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]: async def process_sample( - client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None + client: Any, + sample: dict, + sampling_params: dict, + model: str, + reasoning_effort: Optional[str] = None, + lora_path: Optional[str] = None, ) -> Tuple[dict, str]: """Send a single sample to the LLM and return (sample, response).""" prompt = sample["final_input_prompt"] @@ -82,23 +90,32 @@ async def process_sample( image = sample["image"] assert image is not None image_path = sample["image_path"] - extra_body = None if lora_path is None else {"lora_path": lora_path} + if image_path and not image_path.startswith(("http://", "https://", "data:")): + p = Path(image_path) + mime = mimetypes.guess_type(str(p))[0] or "image/png" + with open(p, "rb") as f: + b64 = base64.b64encode(f.read()).decode() + image_url = f"data:{mime};base64,{b64}" + else: + image_url = image_path + extra_body = {"lora_path": lora_path} if lora_path else None payload = { - "model": "default", + "model": model, "messages": [ { "role": "user", "content": [ {"type": "text", "text": prefix}, - {"type": "image_url", "image_url": {"url": image_path}}, + {"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": suffix}, ], } ], "extra_body": extra_body, + **sampling_params, } - if sampling_params: - payload.update(sampling_params) + if reasoning_effort: + payload["reasoning_effort"] = reasoning_effort response = await client.chat.completions.create(**payload) return sample, response.choices[0].message.content @@ -108,11 +125,15 @@ async def 
process_sample_with_semaphore( client: Any, sample: dict, sampling_params: dict, + model: str, + reasoning_effort: Optional[str] = None, lora_path: Optional[str] = None, ) -> Tuple[dict, str]: """Wrap process_sample with a semaphore for concurrency control.""" async with semaphore: - return await process_sample(client, sample, sampling_params, lora_path) + return await process_sample( + client, sample, sampling_params, model, reasoning_effort, lora_path + ) async def eval_mmmu(args) -> None: @@ -120,6 +141,8 @@ async def eval_mmmu(args) -> None: eval_args = EvalArgs.from_cli_args(args) sampling_params = get_sampling_params(eval_args) samples = prepare_samples(eval_args) + model = args.model + reasoning_effort = eval_args.reasoning_effort lora_path = eval_args.lora_path answer_dict = {} out_samples = {} @@ -146,7 +169,7 @@ async def eval_mmmu(args) -> None: # this is mainly for profiling for sample in tqdm(samples): _, response = await process_sample( - client, sample, sampling_params, lora_path + client, sample, sampling_params, model, reasoning_effort, lora_path ) sample["original_response"] = response answer = ( @@ -164,7 +187,8 @@ async def eval_mmmu(args) -> None: semaphore = asyncio.Semaphore(args.concurrency) tasks = [ process_sample_with_semaphore( - semaphore, client, sample, sampling_params, lora_path + semaphore, client, sample, sampling_params, model, + reasoning_effort, lora_path, ) for sample in samples ] @@ -202,6 +226,10 @@ async def eval_mmmu(args) -> None: def parse_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--model", type=str, default="default", + help="Model name to use in API requests.", + ) EvalArgs.add_cli_args(parser) args = add_common_sglang_args_and_parse(parser) return args diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index b3edd69fc1ce..1a358ab6c827 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -40,6 +40,7 @@ class EvalArgs: temperature: 
Optional[float] = None response_answer_regex: str = "(.*)" lora_path: Optional[str] = None + reasoning_effort: Optional[str] = None @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -120,6 +121,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=EvalArgs.lora_path, help="Specify the LoRA path to use for evaluation. If specified, the value will be specified in the body of every request as `lora-path`.", ) + parser.add_argument( + "--reasoning-effort", + type=str, + default=EvalArgs.reasoning_effort, + choices=["none", "high"], + help="Reasoning effort for the model (none or high).", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): From 0f1471ec8dbe3a3b704d267f2484f63d49ac6876 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Thu, 12 Mar 2026 20:03:18 +0000 Subject: [PATCH 12/27] cleanup: remove redundant activation_type mapping and unused ncols variable The flashinfer trtllm_fp8_per_tensor_scale_moe already defaults activation_type to Swiglu (3), which matches Mistral-Small-4's silu+gated config. Also replace unused ncols with _ in pixtral processor. --- .../moe/moe_runner/flashinfer_trtllm.py | 28 ------------------- .../srt/multimodal/processors/pixtral.py | 2 +- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index fa7b700cfb8d..b76e37b0b6bb 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -17,31 +17,6 @@ MoeRunnerConfig, register_fused_func, ) - - -def _get_flashinfer_activation_type(activation: str, is_gated: bool) -> int: - """Map SGLang activation config to flashinfer ActivationType int values. - - See flashinfer.fused_moe.core.ActivationType for the enum definition. 
- """ - if is_gated: - if activation == "silu": - return 3 # Swiglu - elif activation == "gelu": - return 4 # Geglu - else: - raise ValueError(f"Unsupported gated activation: {activation}") - else: - if activation == "silu": - return 2 # Silu - elif activation == "gelu": - return 0 # Gelu - elif activation == "relu": - return 1 # Relu - else: - raise ValueError(f"Unsupported activation: {activation}") - - from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, scaled_fp8_quant, @@ -353,9 +328,6 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( use_routing_scales_on_input=quant_info.use_routing_scales_on_input, routing_method_type=routing_method_type, tune_max_num_tokens=next_power_of_2(a_q.shape[0]), - activation_type=_get_flashinfer_activation_type( - runner_config.activation, runner_config.is_gated - ), ) return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index c8e6b5be0307..df50d10698ce 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -81,7 +81,7 @@ async def process_mm_data_async( if ratio > 1: w = int(math.floor(w / ratio)) h = int(math.floor(h / ratio)) - nrows, ncols = _get_pixtral_hf_num_image_tokens( + nrows, _ = _get_pixtral_hf_num_image_tokens( (h, w), (effective_patch, effective_patch) ) image_nrows.append(nrows) From e10aa5a53782acca0d3ef05e5b5e5f4be8a5b9a1 Mon Sep 17 00:00:00 2001 From: Alex Nails Date: Mon, 16 Mar 2026 09:06:45 +0000 Subject: [PATCH 13/27] fix reasoning trace having answer and benchmark getting no answers eval with 0% accuracy when thinking --- benchmark/mmmu/bench_sglang.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index 27b852c83bf0..24ea8f3bff34 100644 --- a/benchmark/mmmu/bench_sglang.py +++ 
b/benchmark/mmmu/bench_sglang.py @@ -117,7 +117,11 @@ async def process_sample( if reasoning_effort: payload["reasoning_effort"] = reasoning_effort response = await client.chat.completions.create(**payload) - return sample, response.choices[0].message.content + msg = response.choices[0].message + content = msg.content + if content is None: + content = getattr(msg, "reasoning_content", None) + return sample, content async def process_sample_with_semaphore( From 2041c65642c54dcf0c541b99c9b5b36a91768ecc Mon Sep 17 00:00:00 2001 From: Alex Nails Date: Mon, 16 Mar 2026 09:07:06 +0000 Subject: [PATCH 14/27] possible fix for -HF chkpt --- .../srt/multimodal/processors/pixtral.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index df50d10698ce..9d90baaae27d 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -33,9 +33,17 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.image_size = self.vision_config.image_size self.patch_size = self.vision_config.patch_size + # spatial_merge_size may live on vision_config (Mistral native) or + # on the top-level config (HF native Mistral3Config). 
+ self._spatial_merge_size = getattr( + self.vision_config, + "spatial_merge_size", + getattr(hf_config, "spatial_merge_size", 1), + ) + self._processor.patch_size = self.patch_size - if hasattr(self.vision_config, "spatial_merge_size"): - self._processor.spatial_merge_size = self.vision_config.spatial_merge_size + if self._spatial_merge_size > 1: + self._processor.spatial_merge_size = self._spatial_merge_size tokenizer = ( _processor @@ -69,11 +77,7 @@ async def process_mm_data_async( return_text=True, ) if mm_data.images: - # Track per-image row counts for multi-image offset splitting - spatial_merge_size = getattr( - self.vision_config, "spatial_merge_size", 1 - ) - effective_patch = self.patch_size * spatial_merge_size + effective_patch = self.patch_size * self._spatial_merge_size image_nrows = [] for img in mm_data.images: w, h = img.size @@ -99,6 +103,7 @@ async def process_mm_data_async( ) all_offsets = old_item.offsets old_feature = old_item.feature + old_image_sizes = getattr(old_item, "image_sizes", None) mm_items = [ item for item in mm_items if item.modality != Modality.IMAGE @@ -111,6 +116,10 @@ async def process_mm_data_async( new_item = MultimodalDataItem(modality=Modality.IMAGE) new_item.feature = old_feature[i : i + 1] new_item.offsets = item_offsets + if old_image_sizes is not None: + new_item.model_specific_data["image_sizes"] = ( + old_image_sizes[i : i + 1] + ) mm_items.append(new_item) else: mm_items, input_ids, _ = self.process_and_combine_mm_data( From 01c72d67426681fddfd613fe18759c8851cba629 Mon Sep 17 00:00:00 2001 From: Alex Nails Date: Mon, 16 Mar 2026 09:07:23 +0000 Subject: [PATCH 15/27] LeanStral works --- python/sglang/srt/utils/hf_transformers_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index ddfbe75fe62c..471c9a6e5125 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ 
b/python/sglang/srt/utils/hf_transformers_utils.py @@ -309,6 +309,7 @@ def get_config( if ( "mistral-large-3" in str(model).lower() or "mistral-small-4" in str(model).lower() + or "leanstral" in str(model).lower() ): config = _load_mistral_large_3_for_causal_LM( model, trust_remote_code=trust_remote_code, revision=revision @@ -590,6 +591,7 @@ def get_processor( if ( "mistral-large-3" in str(tokenizer_name).lower() or "mistral-small-4" in str(tokenizer_name).lower() + or "leanstral" in str(tokenizer_name).lower() ): config = _load_mistral_large_3_for_causal_LM( tokenizer_name, @@ -698,6 +700,15 @@ def get_processor( tokenizer = get_tokenizer_from_processor(processor) + if tokenizer.chat_template is None: + local_path = download_from_hf( + tokenizer_name, allow_patterns=["*.json", "*.jinja", "*.model"] + ) + jinja_path = Path(local_path) / "chat_template.jinja" + if jinja_path.is_file(): + tokenizer.chat_template = jinja_path.read_text() + logger.info("Loaded chat_template from %s", jinja_path) + attach_additional_stop_token_ids(tokenizer) return processor From afe877265e6363f7880d8bb59b8f0eab546c7f51 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 17:26:33 +0000 Subject: [PATCH 16/27] lint Signed-off-by: Xinyuan Tong --- benchmark/mmmu/bench_sglang.py | 13 ++++++++++--- python/sglang/srt/multimodal/processors/pixtral.py | 6 +++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index 24ea8f3bff34..0a28c7fc270c 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -191,8 +191,13 @@ async def eval_mmmu(args) -> None: semaphore = asyncio.Semaphore(args.concurrency) tasks = [ process_sample_with_semaphore( - semaphore, client, sample, sampling_params, model, - reasoning_effort, lora_path, + semaphore, + client, + sample, + sampling_params, + model, + reasoning_effort, + lora_path, ) for sample in samples ] @@ -231,7 +236,9 @@ async def 
eval_mmmu(args) -> None: def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--model", type=str, default="default", + "--model", + type=str, + default="default", help="Model name to use in API requests.", ) EvalArgs.add_cli_args(parser) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 9d90baaae27d..47b1513e8fd6 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -117,9 +117,9 @@ async def process_mm_data_async( new_item.feature = old_feature[i : i + 1] new_item.offsets = item_offsets if old_image_sizes is not None: - new_item.model_specific_data["image_sizes"] = ( - old_image_sizes[i : i + 1] - ) + new_item.model_specific_data["image_sizes"] = old_image_sizes[ + i : i + 1 + ] mm_items.append(new_item) else: mm_items, input_ids, _ = self.process_and_combine_mm_data( From d508481e63ecacff7bc1919155da14aeaf46b7e3 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 17:38:24 +0000 Subject: [PATCH 17/27] fix: update model name in MistralDetector docstring (2602 -> 2603) --- python/sglang/srt/parser/reasoning_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 82fd5c814c0b..f46ebf8379ba 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -444,7 +444,7 @@ def __init__( class MistralDetector(BaseReasoningFormatDetector): """ - Detector for Mistral models with reasoning (e.g., Mistral-Small-4-119B-2602). + Detector for Mistral models with reasoning (e.g., Mistral-Small-4-119B-2603). 
Assumes reasoning format: [THINK]reasoning content[/THINK]answer From 34a699f0f0ca2cfe5367e716bde1251c4d8354ee Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 18:37:23 +0000 Subject: [PATCH 18/27] fix: expose mistral load format and update MistralDetector docstring --- python/sglang/srt/server_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 43c62b0bed5c..4d8994ed132b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -82,6 +82,7 @@ "sharded_state", "gguf", "bitsandbytes", + "mistral", "layered", "flash_rl", "remote", From 7da76666d910c05b1049a05abaa280cacf9c1b78 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 19:06:21 +0000 Subject: [PATCH 19/27] fix: use correct custom op name for trtllm_fp8_per_tensor_scale_moe_wrapper --- python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 1a3ef033b891..62005ed46311 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -465,7 +465,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( # Move kernel call outside context manager to avoid graph breaks # during torch.compile for piecewise cuda graph. # Use custom op wrapper for torch.compile compatibility. 
- output = torch.ops.sglang.trtllm_fp8_per_tensor_scale_moe( + output = torch.ops.sglang.trtllm_fp8_per_tensor_scale_moe_wrapper( routing_logits=router_logits.to(torch.bfloat16), routing_bias=routing_bias_cast, hidden_states=a_q, From c04df33ee7a782010d8042b86621da9d52625e25 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 19:44:39 +0000 Subject: [PATCH 20/27] feat: auto-detect Mistral native format and set load_format='mistral' --- python/sglang/srt/server_args.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4d8994ed132b..075a5dc59134 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2905,6 +2905,12 @@ def _handle_load_format(self): ) and check_gguf_file(self.model_path): self.quantization = self.load_format = "gguf" + if self.load_format == "auto" and self._is_mistral_native_format(): + self.load_format = "mistral" + logger.info( + "Detected Mistral native format checkpoint, setting load_format='mistral'" + ) + if is_remote_url(self.model_path): self.load_format = "remote" @@ -2943,6 +2949,22 @@ def _handle_load_format(self): self.validate_transfer_engine() ) + def _is_mistral_native_format(self) -> bool: + """Detect if the model uses Mistral native format (params.json + consolidated weights).""" + if os.path.isdir(self.model_path): + return os.path.exists(os.path.join(self.model_path, "params.json")) + # For hub models, check remote files + try: + from huggingface_hub import HfApi + + files = { + s.rfilename + for s in HfApi().model_info(self.model_path).siblings + } + return "params.json" in files + except Exception: + return False + def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": self.disable_radix_cache = True From e22540b4d62059350b933327757e4df9787d9248 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 22:04:14 +0000 Subject: [PATCH 21/27] lint 
Signed-off-by: Xinyuan Tong --- python/sglang/srt/server_args.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 075a5dc59134..fd908155ae11 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2957,10 +2957,7 @@ def _is_mistral_native_format(self) -> bool: try: from huggingface_hub import HfApi - files = { - s.rfilename - for s in HfApi().model_info(self.model_path).siblings - } + files = {s.rfilename for s in HfApi().model_info(self.model_path).siblings} return "params.json" in files except Exception: return False From bbc726708abb65ff233db0ccfe7cb8e7e001ef35 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 22:04:34 +0000 Subject: [PATCH 22/27] fix: add defaults to PretrainedConfig subclass annotations for transformers v5 compat --- python/sglang/srt/configs/deepseek_ocr.py | 4 ++-- python/sglang/srt/configs/deepseekvl2.py | 6 +++--- python/sglang/srt/configs/janus_pro.py | 24 +++++++++++------------ python/sglang/srt/configs/jet_nemotron.py | 24 +++++++++++------------ 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/configs/deepseek_ocr.py b/python/sglang/srt/configs/deepseek_ocr.py index 1677423d1811..b742ff036bc0 100644 --- a/python/sglang/srt/configs/deepseek_ocr.py +++ b/python/sglang/srt/configs/deepseek_ocr.py @@ -781,8 +781,8 @@ def __init__( class DeepseekVLV2Config(PretrainedConfig): # model_type = "deepseek_vl_v2" model_type = "deepseek-ocr" - vision_config: VisionEncoderConfig - projector_config: MlpProjectorConfig + vision_config: VisionEncoderConfig = None + projector_config: MlpProjectorConfig = None tile_tag: str = "2D" global_view_pos: str = "head" diff --git a/python/sglang/srt/configs/deepseekvl2.py b/python/sglang/srt/configs/deepseekvl2.py index 9621f058bf63..e8f784258954 100644 --- a/python/sglang/srt/configs/deepseekvl2.py +++ 
b/python/sglang/srt/configs/deepseekvl2.py @@ -649,9 +649,9 @@ def __init__( class DeepseekVL2Config(PretrainedConfig): model_type = "deepseek_vl_v2" - vision_config: DeepseekVL2VisionEncoderConfig - projector_config: DeepseekVL2MlpProjectorConfig - language_config: DeepseekV2Config + vision_config: DeepseekVL2VisionEncoderConfig = None + projector_config: DeepseekVL2MlpProjectorConfig = None + language_config: DeepseekV2Config = None tile_tag: str = "2D" global_view_pos: str = "head" diff --git a/python/sglang/srt/configs/janus_pro.py b/python/sglang/srt/configs/janus_pro.py index d574953e95d9..47bb92d2fa41 100644 --- a/python/sglang/srt/configs/janus_pro.py +++ b/python/sglang/srt/configs/janus_pro.py @@ -123,14 +123,14 @@ class SigLIPVisionCfg: class MultiModalityConfig(PretrainedConfig): model_type = "multi_modality" - vision_config: VisionConfig - aligner_config: AlignerConfig + vision_config: VisionConfig = None + aligner_config: AlignerConfig = None - gen_vision_config: GenVisionConfig - gen_aligner_config: GenAlignerConfig - gen_head_config: GenHeadConfig + gen_vision_config: GenVisionConfig = None + gen_aligner_config: GenAlignerConfig = None + gen_head_config: GenHeadConfig = None - language_config: LlamaConfig + language_config: LlamaConfig = None def __init__(self, **kwargs): super().__init__(**kwargs) @@ -595,12 +595,12 @@ def batchify( class VLMImageProcessorConfig(PretrainedConfig): model_type = "deepseek_vlm" - image_size: int - min_size: int - image_mean: Union[Tuple[float, float, float], List[float]] - image_std: Union[Tuple[float, float, float], List[float]] - rescale_factor: float - do_normalize: bool + image_size: int = None + min_size: int = None + image_mean: Union[Tuple[float, float, float], List[float]] = None + image_std: Union[Tuple[float, float, float], List[float]] = None + rescale_factor: float = None + do_normalize: bool = None def __init__( self, diff --git a/python/sglang/srt/configs/jet_nemotron.py 
b/python/sglang/srt/configs/jet_nemotron.py index 1670da3b67f5..9fa172699f08 100644 --- a/python/sglang/srt/configs/jet_nemotron.py +++ b/python/sglang/srt/configs/jet_nemotron.py @@ -25,18 +25,18 @@ class JetBlockConfig: class JetNemotronConfig(PretrainedConfig): model_type: str = "jet_nemotron" - efficient_attention_config: dict[str, dict[str, Any]] - hidden_act: str - hidden_size: int - initializer_range: float - intermediate_size: int - layer_types: list[str] - max_position_embeddings: int - num_attention_heads: int - num_key_value_heads: int - rms_norm_eps: float - rope_scaling: None - rope_theta: float + efficient_attention_config: dict[str, dict[str, Any]] = None + hidden_act: str = None + hidden_size: int = None + initializer_range: float = None + intermediate_size: int = None + layer_types: list[str] = None + max_position_embeddings: int = None + num_attention_heads: int = None + num_key_value_heads: int = None + rms_norm_eps: float = None + rope_scaling: None = None + rope_theta: float = None @property def full_attention_layer_ids(self) -> list[int]: From 638f439a168db7add015e3a271ba51c3c065639c Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 23:20:16 +0000 Subject: [PATCH 23/27] fix: only pass reasoning_effort to chat template when explicitly set Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- .../srt/entrypoints/openai/serving_chat.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 9a218ffc3200..4782fa7e23eb 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -469,19 +469,20 @@ def _apply_jinja_template( self._handle_last_assistant_message(openai_compatible_messages, request) ) + extra_template_kwargs = {} + if request.reasoning_effort is not None: + 
extra_template_kwargs["reasoning_effort"] = request.reasoning_effort + if request.chat_template_kwargs: + extra_template_kwargs.update(request.chat_template_kwargs) + try: prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, tools=tools, - reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), return_dict=False, + **extra_template_kwargs, ) except Exception as e: # If the first attempt fails, try with flat function-only format. @@ -497,13 +498,8 @@ def _apply_jinja_template( tokenize=True, add_generation_prompt=True, tools=tools, - reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), return_dict=False, + **extra_template_kwargs, ) except jinja2.TemplateError as template_error: # Template errors (e.g., from raise_exception in Jinja templates) @@ -1259,8 +1255,12 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: or request.chat_template_kwargs.get("enable_thinking") is not False ) if self.reasoning_parser in ["mistral"]: - # Mistral models only reason when reasoning_effort="high" - return request.reasoning_effort == "high" + # Mistral models only reason when reasoning_effort is explicitly + # set to a value other than None/"none" (typically "high"). 
+ return ( + request.reasoning_effort is not None + and request.reasoning_effort != "none" + ) return True # default async def _process_tool_call_stream( From 77675d1f4a5f69bf424d8603dc61496ea4f3b266 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Mar 2026 23:20:44 +0000 Subject: [PATCH 24/27] fix: support multiple consecutive compact tool calls in Mistral detector Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- .../srt/function_call/mistral_detector.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index b1268b90fa0a..8cb412d3f00e 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -90,19 +90,27 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult return StreamingParseResult(normal_text=combined_normal, calls=calls) # Compact: `[TOOL_CALLS]tool_name[ARGS]{...}` - parsed = self._try_parse_compact_args_format(tool_part) - if not parsed: + # Loop to extract all consecutive compact tool calls. 
+ all_calls: list = [] + remaining = tool_part + while remaining: + parsed = self._try_parse_compact_args_format(remaining) + if not parsed: + break + func_name, args_obj, consumed = parsed + new_calls = self.parse_base_json( + {"name": func_name, "arguments": args_obj}, tools + ) + all_calls.extend(new_calls) + remaining = remaining[consumed:].strip() + + if not all_calls: return StreamingParseResult(normal_text=normal_text, calls=[]) - func_name, args_obj, consumed = parsed - calls = self.parse_base_json({"name": func_name, "arguments": args_obj}, tools) - trailing_text = tool_part[consumed:].strip() combined_normal = ( - (normal_text + " " + trailing_text).strip() - if trailing_text - else normal_text + (normal_text + " " + remaining).strip() if remaining else normal_text ) - return StreamingParseResult(normal_text=combined_normal, calls=calls) + return StreamingParseResult(normal_text=combined_normal, calls=all_calls) def parse_streaming_increment( self, new_text: str, tools: List[Tool] From a6444027031c92bc5370fa1f974a45953e620c9d Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Tue, 17 Mar 2026 01:46:05 +0000 Subject: [PATCH 25/27] fix: workaround Mistral tokenizer marking [THINK]/[/THINK] as special tokens MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mistral's tokenizer defines [THINK] (id=34) and [/THINK] (id=35) as special tokens. When skip_special_tokens=True (the default), these tokens are stripped during decoding, making the reasoning parser unable to detect thinking boundaries and split reasoning_content from content. This is an upstream issue in the Mistral checkpoint/tokenizer config — reasoning markers should not be special tokens (cf. DeepSeek's <think>/</think> which are regular tokens and work without workarounds). As a workaround, disable skip_special_tokens when the Mistral reasoning parser is active and reasoning_effort is set.
--- .../sglang/srt/entrypoints/openai/serving_chat.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 4782fa7e23eb..1c8f4d94f6a2 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -333,6 +333,8 @@ def _process_messages( if self.is_gpt_oss: request.skip_special_tokens = False + self._patch_mistral_skip_special_tokens(request) + tool_call_constraint = None # Apply chat template and its stop strings @@ -1230,6 +1232,18 @@ def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int: idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa return idx + def _patch_mistral_skip_special_tokens( + self, request: ChatCompletionRequest + ) -> None: + """Mistral uses special tokens ([THINK]/[/THINK]) for reasoning markers, + which get stripped when skip_special_tokens=True.""" + if ( + self.reasoning_parser in ["mistral"] + and request.reasoning_effort is not None + and request.reasoning_effort != "none" + ): + request.skip_special_tokens = False + def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """Judge whether the request needs reasoning for hybrid reasoning models NOTE: This is predefined based on model's chat template From 773f8517125b81576587d2f4e3701f86b692589a Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Tue, 17 Mar 2026 04:50:17 +0000 Subject: [PATCH 26/27] fix: support dense EAGLE speculative decoding for Mistral Small 4 The EAGLE draft model for Mistral Small 4 (mistralai/Mistral-Small-4-119B-2603-eagle) uses dense MLA layers without MoE, unlike the Mistral Large 3 EAGLE which has MoE. This caused three issues: 1. `adapt_config_dict` in mistral_utils.py did not handle dense EAGLE models (moe=null in params.json), falling through to an unsupported architecture. 
Fix: add a branch for `is_eagle and not is_moe` that sets model_type=deepseek_v3 with all-dense MoE overrides (first_k_dense_replace=num_layers). 2. `_remap_mistral_yarn_args` did not include rope_theta in rope_scaling, causing transformers yarn validation to fail. Fix: copy rope_theta into the rope_scaling dict. 3. `MistralLarge3ForCausalLMEagle.__init__` set `self.model_cls` but `DeepseekV2ForCausalLM.__init__` hardcodes `self.model = DeepseekV2Model`, so the EAGLE fc layer was never created. The draft model ran without fusing token embeddings with target hidden states, producing garbage draft tokens (accept rate 0.25). Fix: call super().__init__() then replace self.model with MistralLarge3EagleModel which has the fc layer. Accept rate: 0.25 -> 0.83. --- .../srt/models/mistral_large_3_eagle.py | 14 +++++++--- python/sglang/srt/utils/mistral_utils.py | 26 ++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py index a5ce7b6aabb6..0860b2801503 100644 --- a/python/sglang/srt/models/mistral_large_3_eagle.py +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -18,7 +18,10 @@ from sglang.srt.utils import add_prefix -class MistralLarge3Model(DeepseekV2Model): +class MistralLarge3EagleModel(DeepseekV2Model): + """EAGLE draft model with an fc layer that fuses token embeddings and + target-model hidden states before passing through transformer layers.""" + def __init__( self, config: PretrainedConfig, @@ -99,9 +102,14 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): - config.quant_config = quant_config - self.model_cls = MistralLarge3Model + # DeepseekV2ForCausalLM.__init__ hardcodes self.model = DeepseekV2Model. + # We let the parent init run (it sets up weight loading attrs, lm_head, + # etc.), then replace self.model with MistralLarge3EagleModel which has + # the EAGLE fc layer. 
The discarded 2-layer DeepseekV2Model is tiny. super().__init__(config=config, quant_config=quant_config, prefix=prefix) + self.model = MistralLarge3EagleModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) EntryClass = [MistralLarge3ForCausalLMEagle] diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index 52f8769b3084..4955c0575094 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -23,7 +23,27 @@ def adapt_config_dict( is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 ) is_eagle = "eagle" in model.lower() - if is_moe: + if is_eagle and not is_moe: + # Dense EAGLE draft model (e.g. Mistral Small 4 EAGLE). + # Uses MLA attention like MistralLarge3 but has no MoE layers. + # Set model_type to deepseek_v3 for MLA support, and override + # MoE fields so all layers are dense. + config_dict["model_type"] = "deepseek_v3" + config_dict["architectures"] = ["MistralLarge3ForCausalLMEagle"] + num_layers = config_dict.get("num_hidden_layers", 0) + config_dict["n_routed_experts"] = 1 + config_dict["first_k_dense_replace"] = num_layers + config_dict["moe_layer_freq"] = 1 + config_dict["n_shared_experts"] = 0 + config_dict["n_group"] = 1 + config_dict["topk_group"] = 1 + config_dict["num_experts_per_tok"] = 1 + config_dict["moe_intermediate_size"] = 1 + config_dict["routed_scaling_factor"] = 1.0 + config_dict["topk_method"] = None + config_dict["scoring_func"] = "softmax" + config_dict["routing_method_type"] = 1 + elif is_moe: if is_mistral_large_3: config_dict = _remap_moe_args(config_dict) config_dict["model_type"] = "deepseek_v3" @@ -121,6 +141,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict: "rope_type": "yarn", "mscale_all_dim": 1, } + # Include rope_theta in rope_scaling if present at the top level, + # as transformers yarn validation requires it. 
+ if "rope_theta" in config: + config["rope_scaling"]["rope_theta"] = config["rope_theta"] for old_name, new_name in yarn_config_map.items(): if old_name in yarn_config: value = yarn_config.pop(old_name) From 2daf4ccc982d1f2ff36c66f6d7181ff91d7f1cfb Mon Sep 17 00:00:00 2001 From: dbari Date: Tue, 17 Mar 2026 16:36:20 +0000 Subject: [PATCH 27/27] fix: respect apply_scale=false in Mistral yarn RoPE config Mistral Small 4's params.json sets "apply_scale": false in the yarn config, meaning the mscale factor should NOT be applied to attention logits scaling. Previously this field was discarded, causing an incorrect 2.2x mscale to be applied unconditionally. Changes: - Map "apply_scale" to "apply_yarn_scaling" in rope_scaling dict instead of dropping it - Use "deepseek_yarn" rope_type to avoid transformers yarn validation issues - Gate mscale application in DeepseekV2AttentionMLA on apply_yarn_scaling gsm8k 5-shot exact_match: 0.7976 -> 0.8901 (+9.3%) --- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/utils/mistral_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8f061714223b..4750a6532095 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1198,7 +1198,7 @@ def __init__( device=get_global_server_args().device, ) - if rope_scaling: + if rope_scaling and rope_scaling.get("apply_yarn_scaling", True): mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index 4955c0575094..dc9e08d945e0 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -134,11 +134,11 @@ def _remap_mistral_yarn_args(config: dict) -> dict: 
"original_max_position_embeddings": "original_max_position_embeddings", "beta": "beta_fast", "alpha": "beta_slow", - "apply_scale": None, + "apply_scale": "apply_yarn_scaling", } yarn_config = config.get("yarn") or {} config["rope_scaling"] = { - "rope_type": "yarn", + "rope_type": "deepseek_yarn", "mscale_all_dim": 1, } # Include rope_theta in rope_scaling if present at the top level,