diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index d9426ae5a3ac..0a28c7fc270c 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -11,11 +11,14 @@ import argparse import asyncio +import base64 +import mimetypes import re import sys import time import traceback from dataclasses import dataclass, field +from pathlib import Path from typing import Any, List, Optional, Tuple import aiohttp @@ -74,7 +77,12 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]: async def process_sample( - client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None + client: Any, + sample: dict, + sampling_params: dict, + model: str, + reasoning_effort: Optional[str] = None, + lora_path: Optional[str] = None, ) -> Tuple[dict, str]: """Send a single sample to the LLM and return (sample, response).""" prompt = sample["final_input_prompt"] @@ -82,25 +90,38 @@ async def process_sample( image = sample["image"] assert image is not None image_path = sample["image_path"] - extra_body = None if lora_path is None else {"lora_path": lora_path} + if image_path and not image_path.startswith(("http://", "https://", "data:")): + p = Path(image_path) + mime = mimetypes.guess_type(str(p))[0] or "image/png" + with open(p, "rb") as f: + b64 = base64.b64encode(f.read()).decode() + image_url = f"data:{mime};base64,{b64}" + else: + image_url = image_path + extra_body = {"lora_path": lora_path} if lora_path else None payload = { - "model": "default", + "model": model, "messages": [ { "role": "user", "content": [ {"type": "text", "text": prefix}, - {"type": "image_url", "image_url": {"url": image_path}}, + {"type": "image_url", "image_url": {"url": image_url}}, {"type": "text", "text": suffix}, ], } ], "extra_body": extra_body, + **sampling_params, } - if sampling_params: - payload.update(sampling_params) + if reasoning_effort: + payload["reasoning_effort"] = reasoning_effort response = await client.chat.completions.create(**payload) - return sample, response.choices[0].message.content + msg = response.choices[0].message + content = msg.content + if content is None: + content = getattr(msg, "reasoning_content", None) + return sample, content async def process_sample_with_semaphore( @@ -108,11 +129,15 @@ async def process_sample_with_semaphore( client: Any, sample: dict, sampling_params: dict, + model: str, + reasoning_effort: Optional[str] = None, lora_path: Optional[str] = None, ) -> Tuple[dict, str]: """Wrap process_sample with a semaphore for concurrency control.""" async with semaphore: - return await process_sample(client, sample, sampling_params, lora_path) + return await process_sample( + client, sample, sampling_params, model, reasoning_effort, lora_path + ) async def eval_mmmu(args) -> None: @@ -120,6 +145,8 @@ async def eval_mmmu(args) -> None: eval_args = EvalArgs.from_cli_args(args) sampling_params = get_sampling_params(eval_args) samples = prepare_samples(eval_args) + model = args.model + reasoning_effort = eval_args.reasoning_effort lora_path = eval_args.lora_path answer_dict = {} out_samples = {} @@ -146,7 +173,7 @@ async def eval_mmmu(args) -> None: # this is mainly for profiling for sample in tqdm(samples): _, response = await process_sample( - client, sample, sampling_params, lora_path + client, sample, sampling_params, model, reasoning_effort, lora_path ) sample["original_response"] = response answer = ( @@ -164,7 +191,13 @@ async def eval_mmmu(args) -> None: semaphore = asyncio.Semaphore(args.concurrency) tasks = [ process_sample_with_semaphore( - semaphore, client, sample, sampling_params, lora_path + semaphore, + client, + sample, + sampling_params, + model, + reasoning_effort, + lora_path, ) for sample in samples ] @@ -202,6 +235,12 @@ async def eval_mmmu(args) -> None: def parse_args(): parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="default", + help="Model name to use in API requests.", + ) EvalArgs.add_cli_args(parser) args = add_common_sglang_args_and_parse(parser) return args diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index b3edd69fc1ce..1a358ab6c827 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -40,6 +40,7 @@ class EvalArgs: temperature: Optional[float] = None response_answer_regex: str = "(.*)" lora_path: Optional[str] = None + reasoning_effort: Optional[str] = None @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -120,6 +121,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=EvalArgs.lora_path, help="Specify the LoRA path to use for evaluation. If specified, the value will be specified in the body of every request as `lora-path`.", ) + parser.add_argument( + "--reasoning-effort", + type=str, + default=EvalArgs.reasoning_effort, + choices=["none", "high"], + help="Reasoning effort for the model (none or high).", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/srt/configs/deepseek_ocr.py b/python/sglang/srt/configs/deepseek_ocr.py index 1677423d1811..b742ff036bc0 100644 --- a/python/sglang/srt/configs/deepseek_ocr.py +++ b/python/sglang/srt/configs/deepseek_ocr.py @@ -781,8 +781,8 @@ def __init__( class DeepseekVLV2Config(PretrainedConfig): # model_type = "deepseek_vl_v2" model_type = "deepseek-ocr" - vision_config: VisionEncoderConfig - projector_config: MlpProjectorConfig + vision_config: VisionEncoderConfig = None + projector_config: MlpProjectorConfig = None tile_tag: str = "2D" global_view_pos: str = "head" diff --git a/python/sglang/srt/configs/deepseekvl2.py b/python/sglang/srt/configs/deepseekvl2.py index 9621f058bf63..e8f784258954 100644 --- a/python/sglang/srt/configs/deepseekvl2.py +++ b/python/sglang/srt/configs/deepseekvl2.py @@ -649,9 +649,9 @@ def __init__( class DeepseekVL2Config(PretrainedConfig): model_type = "deepseek_vl_v2" - vision_config: DeepseekVL2VisionEncoderConfig - projector_config: DeepseekVL2MlpProjectorConfig - language_config: DeepseekV2Config + vision_config: DeepseekVL2VisionEncoderConfig = None + projector_config: DeepseekVL2MlpProjectorConfig = None + language_config: DeepseekV2Config = None tile_tag: str = "2D" global_view_pos: str = "head" diff --git a/python/sglang/srt/configs/janus_pro.py b/python/sglang/srt/configs/janus_pro.py index d574953e95d9..47bb92d2fa41 100644 --- a/python/sglang/srt/configs/janus_pro.py +++ b/python/sglang/srt/configs/janus_pro.py @@ -123,14 +123,14 @@ class SigLIPVisionCfg: class MultiModalityConfig(PretrainedConfig): model_type = "multi_modality" - vision_config: VisionConfig - aligner_config: AlignerConfig + vision_config: VisionConfig = None + aligner_config: AlignerConfig = None - gen_vision_config: GenVisionConfig - gen_aligner_config: GenAlignerConfig - gen_head_config: GenHeadConfig + gen_vision_config: GenVisionConfig = None + gen_aligner_config: GenAlignerConfig = None + gen_head_config: GenHeadConfig = None - language_config: LlamaConfig + language_config: LlamaConfig = None def __init__(self, **kwargs): super().__init__(**kwargs) @@ -595,12 +595,12 @@ def batchify( class VLMImageProcessorConfig(PretrainedConfig): model_type = "deepseek_vlm" - image_size: int - min_size: int - image_mean: Union[Tuple[float, float, float], List[float]] - image_std: Union[Tuple[float, float, float], List[float]] - rescale_factor: float - do_normalize: bool + image_size: int = None + min_size: int = None + image_mean: Union[Tuple[float, float, float], List[float]] = None + image_std: Union[Tuple[float, float, float], List[float]] = None + rescale_factor: float = None + do_normalize: bool = None def __init__( self, diff --git a/python/sglang/srt/configs/jet_nemotron.py b/python/sglang/srt/configs/jet_nemotron.py index 1670da3b67f5..9fa172699f08 100644 --- a/python/sglang/srt/configs/jet_nemotron.py +++ b/python/sglang/srt/configs/jet_nemotron.py @@ -25,18 +25,18 @@ class JetBlockConfig: class JetNemotronConfig(PretrainedConfig): model_type: str = "jet_nemotron" - efficient_attention_config: dict[str, dict[str, Any]] - hidden_act: str - hidden_size: int - initializer_range: float - intermediate_size: int - layer_types: list[str] - max_position_embeddings: int - num_attention_heads: int - num_key_value_heads: int - rms_norm_eps: float - rope_scaling: None - rope_theta: float + efficient_attention_config: dict[str, dict[str, Any]] = None + hidden_act: str = None + hidden_size: int = None + initializer_range: float = None + intermediate_size: int = None + layer_types: list[str] = None + max_position_embeddings: int = None + num_attention_heads: int = None + num_key_value_heads: int = None + rms_norm_eps: float = None + rope_scaling: None = None + rope_theta: float = None @property def full_attention_layer_ids(self) -> list[int]: diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 1f44b08fe07d..16fa9cbc807c 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -589,7 +589,7 @@ class ChatCompletionRequest(BaseModel): return_routed_experts: bool = False return_cached_tokens_details: bool = False reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = Field( - default="medium", + default=None, description="Constrains effort on reasoning for reasoning models. " "'none' disables reasoning entirely, 'low' is the least effort, 'high' is the most effort. " "Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning " diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 9e0f74dd3760..1c8f4d94f6a2 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -333,6 +333,8 @@ def _process_messages( if self.is_gpt_oss: request.skip_special_tokens = False + self._patch_mistral_skip_special_tokens(request) + tool_call_constraint = None # Apply chat template and its stop strings @@ -469,19 +471,20 @@ def _apply_jinja_template( self._handle_last_assistant_message(openai_compatible_messages, request) ) + extra_template_kwargs = {} + if request.reasoning_effort is not None: + extra_template_kwargs["reasoning_effort"] = request.reasoning_effort + if request.chat_template_kwargs: + extra_template_kwargs.update(request.chat_template_kwargs) + try: prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, tools=tools, - reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), return_dict=False, + **extra_template_kwargs, ) except Exception as e: # If the first attempt fails, try with flat function-only format. @@ -497,13 +500,8 @@ def _apply_jinja_template( tokenize=True, add_generation_prompt=True, tools=tools, - reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), return_dict=False, + **extra_template_kwargs, ) except jinja2.TemplateError as template_error: # Template errors (e.g., from raise_exception in Jinja templates) @@ -1234,8 +1232,22 @@ def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int: idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa return idx + def _patch_mistral_skip_special_tokens( + self, request: ChatCompletionRequest + ) -> None: + """Mistral uses special tokens ([THINK]/[/THINK]) for reasoning markers, + which get stripped when skip_special_tokens=True.""" + if ( + self.reasoning_parser in ["mistral"] + and request.reasoning_effort is not None + and request.reasoning_effort != "none" + ): + request.skip_special_tokens = False + def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: - """Judge whether the request needs reasoning""" + """Judge whether the request needs reasoning for hybrid reasoning models + NOTE: This is predefined based on model's chat template + """ if not self.reasoning_parser: return False if self.reasoning_parser in ["deepseek-v3"]: @@ -1256,6 +1268,13 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking") is not False ) + if self.reasoning_parser in ["mistral"]: + # Mistral models only reason when reasoning_effort is explicitly + # set to a value other than None/"none" (typically "high"). + return ( + request.reasoning_effort is not None + and request.reasoning_effort != "none" + ) return True # default async def _process_tool_call_stream( diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index b1268b90fa0a..8cb412d3f00e 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -90,19 +90,27 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult return StreamingParseResult(normal_text=combined_normal, calls=calls) # Compact: `[TOOL_CALLS]tool_name[ARGS]{...}` - parsed = self._try_parse_compact_args_format(tool_part) - if not parsed: + # Loop to extract all consecutive compact tool calls. + all_calls: list = [] + remaining = tool_part + while remaining: + parsed = self._try_parse_compact_args_format(remaining) + if not parsed: + break + func_name, args_obj, consumed = parsed + new_calls = self.parse_base_json( + {"name": func_name, "arguments": args_obj}, tools + ) + all_calls.extend(new_calls) + remaining = remaining[consumed:].strip() + + if not all_calls: return StreamingParseResult(normal_text=normal_text, calls=[]) - func_name, args_obj, consumed = parsed - calls = self.parse_base_json({"name": func_name, "arguments": args_obj}, tools) - trailing_text = tool_part[consumed:].strip() combined_normal = ( - (normal_text + " " + trailing_text).strip() - if trailing_text - else normal_text + (normal_text + " " + remaining).strip() if remaining else normal_text ) - return StreamingParseResult(normal_text=combined_normal, calls=calls) + return StreamingParseResult(normal_text=combined_normal, calls=all_calls) def parse_streaming_increment( self, new_text: str, tools: List[Tool] diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 1a3ef033b891..62005ed46311 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -465,7 +465,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( # Move kernel call outside context manager to avoid graph breaks # during torch.compile for piecewise cuda graph. # Use custom op wrapper for torch.compile compatibility. - output = torch.ops.sglang.trtllm_fp8_per_tensor_scale_moe( + output = torch.ops.sglang.trtllm_fp8_per_tensor_scale_moe_wrapper( routing_logits=router_logits.to(torch.bfloat16), routing_bias=routing_bias_cast, hidden_states=a_q, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9f407bd546c0..57bae2913561 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1216,8 +1216,11 @@ def __init__( device=get_global_server_args().device, ) - if rope_scaling: - self.scaling = compute_mla_mscale_scaling(rope_scaling, self.scaling) + if rope_scaling and rope_scaling.get("apply_yarn_scaling", True): + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale else: self.rotary_emb = None self.use_deepseek_yarn_rope = rope_scaling is not None diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py index a5ce7b6aabb6..0860b2801503 100644 --- a/python/sglang/srt/models/mistral_large_3_eagle.py +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -18,7 +18,10 @@ from sglang.srt.utils import add_prefix -class MistralLarge3Model(DeepseekV2Model): +class MistralLarge3EagleModel(DeepseekV2Model): + """EAGLE draft model with an fc layer that fuses token embeddings and + target-model hidden states before passing through transformer layers.""" + def __init__( self, config: PretrainedConfig, @@ -99,9 +102,14 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): - config.quant_config = quant_config - self.model_cls = MistralLarge3Model + # DeepseekV2ForCausalLM.__init__ hardcodes self.model = DeepseekV2Model. + # We let the parent init run (it sets up weight loading attrs, lm_head, + # etc.), then replace self.model with MistralLarge3EagleModel which has + # the EAGLE fc layer. The discarded 2-layer DeepseekV2Model is tiny. super().__init__(config=config, quant_config=quant_config, prefix=prefix) + self.model = MistralLarge3EagleModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) EntryClass = [MistralLarge3ForCausalLMEagle] diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py index 2ce96da00c97..ac0f351a994e 100644 --- a/python/sglang/srt/models/pixtral.py +++ b/python/sglang/srt/models/pixtral.py @@ -83,10 +83,13 @@ def __init__(self, *, config, prefix: str = "", **kwargs): super().__init__() self.config = config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + config_dict = self.config.vision_config.to_dict() + if config_dict.get("rope_parameters"): # transformers v5 compatibility + config_dict["rope_theta"] = config_dict["rope_parameters"].get("rope_theta") + config_dict["rope_scaling"] = config_dict["rope_parameters"] + config_dict.pop("rope_parameters") vision_args = { - key: value - for key, value in self.config.vision_config.to_dict().items() - if key in dataclass_fields + key: value for key, value in config_dict.items() if key in dataclass_fields } self.vision_args = VisionEncoderArgs(**vision_args) diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index b923ff342a19..47b1513e8fd6 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -1,11 +1,12 @@ -import asyncio import math from typing import List, Union +from transformers import PreTrainedTokenizerBase from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) +from sglang.srt.managers.schedule_batch import Modality from sglang.srt.models.pixtral import ( PixtralForConditionalGeneration, PixtralVisionModel, @@ -20,63 +21,47 @@ class PixtralProcessor(BaseMultimodalProcessor): models = [PixtralVisionModel, PixtralForConditionalGeneration] PAD_TOKEN = "" - IMG_BREAK_TOKEN_ID = 12 - IMG_END_TOKEN_ID = 13 - - def get_patch_grid_size( - self, - *, - image_width: int, - image_height: int, - ) -> tuple[int, int]: - max_width = max_height = self.image_size - patch_width = patch_height = self.patch_size - - ratio = max(image_width / max_width, image_height / max_height) - if ratio > 1: - image_width = int(math.floor(image_width / ratio)) - image_height = int(math.floor(image_height / ratio)) - - nrows, ncols = _get_pixtral_hf_num_image_tokens( - (image_height, image_width), - (patch_height, patch_width), - ) - - return ncols, nrows + DEFAULT_IMAGE_TOKEN = "[IMG]" def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_TOKEN_ID = getattr( hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID ) - # Instantiate the patcher logic helper using the class defined above self.vision_config = hf_config.vision_config self.image_size = self.vision_config.image_size self.patch_size = self.vision_config.patch_size + # spatial_merge_size may live on vision_config (Mistral native) or + # on the top-level config (HF native Mistral3Config). + self._spatial_merge_size = getattr( + self.vision_config, + "spatial_merge_size", + getattr(hf_config, "spatial_merge_size", 1), + ) + self._processor.patch_size = self.patch_size - if hasattr(self.vision_config, "spatial_merge_size"): - self._processor.spatial_merge_size = self.vision_config.spatial_merge_size + if self._spatial_merge_size > 1: + self._processor.spatial_merge_size = self._spatial_merge_size + + tokenizer = ( + _processor + if isinstance(_processor, PreTrainedTokenizerBase) + else _processor.tokenizer + ) + self.image_token = getattr(_processor, "image_token", self.DEFAULT_IMAGE_TOKEN) self.mm_tokens = MultimodalSpecialTokens( - image_token=_processor.image_token, + image_token=self.image_token, image_token_id=self.IM_TOKEN_ID, ).build(_processor) - _processor.tokenizer.add_special_tokens( + tokenizer.add_special_tokens( { "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN), } ) - async def _resize(self, image): - num_w_tokens, num_h_tokens = self.get_patch_grid_size( - image_width=image.size[0], - image_height=image.size[1], - ) - new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size) - return image.resize(new_size) - async def process_mm_data_async( self, image_data: List[Union[str, bytes]], @@ -92,16 +77,58 @@ async def process_mm_data_async( return_text=True, ) if mm_data.images: - resize_tasks = [self._resize(image) for image in mm_data.images] - mm_data.images = await asyncio.gather(*resize_tasks) - - mm_items, input_ids, _ = self.process_and_combine_mm_data( - mm_data, self.mm_tokens - ) + effective_patch = self.patch_size * self._spatial_merge_size + image_nrows = [] + for img in mm_data.images: + w, h = img.size + ratio = max(w / self.image_size, h / self.image_size) + if ratio > 1: + w = int(math.floor(w / ratio)) + h = int(math.floor(h / ratio)) + nrows, _ = _get_pixtral_hf_num_image_tokens( + (h, w), (effective_patch, effective_patch) + ) + image_nrows.append(nrows) + + mm_items, input_ids, _ = self.process_and_combine_mm_data( + mm_data, self.mm_tokens + ) + + # For multi-image: split single IMAGE mm_item into per-image items + if len(mm_data.images) > 1: + from sglang.srt.managers.schedule_batch import MultimodalDataItem + + old_item = next( + item for item in mm_items if item.modality == Modality.IMAGE + ) + all_offsets = old_item.offsets + old_feature = old_item.feature + old_image_sizes = getattr(old_item, "image_sizes", None) + + mm_items = [ + item for item in mm_items if item.modality != Modality.IMAGE + ] + offset_idx = 0 + for i, img in enumerate(mm_data.images): + nr = image_nrows[i] + item_offsets = all_offsets[offset_idx : offset_idx + nr] + offset_idx += nr + new_item = MultimodalDataItem(modality=Modality.IMAGE) + new_item.feature = old_feature[i : i + 1] + new_item.offsets = item_offsets + if old_image_sizes is not None: + new_item.model_specific_data["image_sizes"] = old_image_sizes[ + i : i + 1 + ] + mm_items.append(new_item) + else: + mm_items, input_ids, _ = self.process_and_combine_mm_data( + mm_data, self.mm_tokens + ) return { "mm_items": mm_items, "input_ids": input_ids.tolist(), "im_token_id": self.IM_TOKEN_ID, - "im_token": self._processor.image_token, + "im_token": self.image_token, } diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index d2d4d5380f79..a6867c9f8b54 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -450,6 +450,33 @@ def detect_and_parse(self, text: str) -> StreamingParseResult: return ret +class MistralDetector(BaseReasoningFormatDetector): + """ + Detector for Mistral models with reasoning (e.g., Mistral-Small-4-119B-2603). + Assumes reasoning format: + [THINK]reasoning content[/THINK]answer + + Reasoning is optional — it only appears when reasoning_effort="high" is set. + When reasoning_effort="none", the model outputs directly without thinking tokens. + """ + + def __init__( + self, + stream_reasoning: bool = True, + force_reasoning: bool = False, + continue_final_message: bool = False, + previous_content: str = "", + ): + super().__init__( + "[THINK]", + "[/THINK]", + force_reasoning=force_reasoning, + stream_reasoning=stream_reasoning, + continue_final_message=continue_final_message, + previous_content=previous_content, + ) + + class ReasoningParser: """ Parser that handles both streaming and non-streaming scenarios for extracting @@ -474,6 +501,7 @@ class ReasoningParser: "minimax-append-think": MiniMaxAppendThinkDetector, "step3": DeepSeekR1Detector, "step3p5": DeepSeekR1Detector, + "mistral": MistralDetector, "nemotron_3": Nemotron3Detector, "interns1": Qwen3Detector, } diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index bad3d2529877..ed815b16a387 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -83,6 +83,7 @@ "sharded_state", "gguf", "bitsandbytes", + "mistral", "layered", "flash_rl", "remote", @@ -2963,6 +2964,12 @@ def _handle_load_format(self): ) and check_gguf_file(self.model_path): self.quantization = self.load_format = "gguf" + if self.load_format == "auto" and self._is_mistral_native_format(): + self.load_format = "mistral" + logger.info( + "Detected Mistral native format checkpoint, setting load_format='mistral'" + ) + if is_remote_url(self.model_path): self.load_format = "remote" @@ -3013,6 +3020,19 @@ def _handle_load_format(self): self.validate_transfer_engine() ) + def _is_mistral_native_format(self) -> bool: + """Detect if the model uses Mistral native format (params.json + consolidated weights).""" + if os.path.isdir(self.model_path): + return os.path.exists(os.path.join(self.model_path, "params.json")) + # For hub models, check remote files + try: + from huggingface_hub import HfApi + + files = {s.rfilename for s in HfApi().model_info(self.model_path).siblings} + return "params.json" in files + except Exception: + return False + def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": self.disable_radix_cache = True diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 8e9e01e7c428..405c350383b0 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -19,6 +19,7 @@ import os import tempfile import warnings +from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional, Type, Union @@ -267,11 +268,11 @@ def _load_deepseek_v32_model( # Temporary hack for Mistral Large +@lru_cache(maxsize=2) def _load_mistral_large_3_for_causal_LM( model_path: str, trust_remote_code: bool = False, revision: Optional[str] = None, - **kwargs, ): # first get the local path local_path = download_from_hf(model_path) @@ -283,7 +284,7 @@ def _load_mistral_large_3_for_causal_LM( json.dump(config_dict, f) f.flush() loaded_config = AutoConfig.from_pretrained( - f.name, trust_remote_code=trust_remote_code, revision=revision, **kwargs + f.name, trust_remote_code=trust_remote_code, revision=revision ) text_config = getattr(loaded_config, "text_config", None) if text_config is not None and isinstance(text_config, dict): @@ -477,9 +478,13 @@ def get_config( client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model = client.get_local_dir() - if "mistral-large-3" in str(model).lower(): + if ( + "mistral-large-3" in str(model).lower() + or "mistral-small-4" in str(model).lower() + or "leanstral" in str(model).lower() + ): config = _load_mistral_large_3_for_causal_LM( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + model, trust_remote_code=trust_remote_code, revision=revision ) else: _ensure_llama_flash_attention2_compat() @@ -1104,12 +1109,15 @@ def get_processor( ): # pop 'revision' from kwargs if present. revision = kwargs.pop("revision", tokenizer_revision) - if "mistral-large-3" in str(tokenizer_name).lower(): + if ( + "mistral-large-3" in str(tokenizer_name).lower() + or "mistral-small-4" in str(tokenizer_name).lower() + or "leanstral" in str(tokenizer_name).lower() + ): config = _load_mistral_large_3_for_causal_LM( tokenizer_name, trust_remote_code=trust_remote_code, revision=revision, - **kwargs, ) else: _ensure_llama_flash_attention2_compat() @@ -1192,8 +1200,49 @@ def get_processor( ) else: raise e + # If processor is a bare tokenizer (e.g. Mistral-Small-4 has no processor_config.json) + # and the model is a vision model (pixtral), wrap it in a proper PixtralProcessor + # so that image data is actually processed through the image processor. + if ( + isinstance(processor, PreTrainedTokenizerBase) + and getattr(config, "model_type", None) == "pixtral" + ): + from transformers.models.pixtral.image_processing_pixtral import ( + PixtralImageProcessor, + ) + from transformers.models.pixtral.processing_pixtral import ( + PixtralProcessor as HFPixtralProcessor, + ) + + vision_config = config.vision_config + patch_size = vision_config.patch_size + image_size = vision_config.image_size + spatial_merge_size = getattr(vision_config, "spatial_merge_size", 1) + + effective_patch = patch_size * spatial_merge_size + image_processor = PixtralImageProcessor( + do_resize=True, + size={"longest_edge": image_size}, + patch_size={"height": effective_patch, "width": effective_patch}, + ) + processor = HFPixtralProcessor( + image_processor=image_processor, + tokenizer=processor, + patch_size=patch_size, + spatial_merge_size=spatial_merge_size, + ) + tokenizer = get_tokenizer_from_processor(processor) + if tokenizer.chat_template is None: + local_path = download_from_hf( + tokenizer_name, allow_patterns=["*.json", "*.jinja", "*.model"] + ) + jinja_path = Path(local_path) / "chat_template.jinja" + if jinja_path.is_file(): + tokenizer.chat_template = jinja_path.read_text() + logger.info("Loaded chat_template from %s", jinja_path) + _fix_special_tokens_pattern(tokenizer) _fix_added_tokens_encoding(tokenizer) attach_additional_stop_token_ids(tokenizer) diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index 52f8769b3084..dc9e08d945e0 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -23,7 +23,27 @@ def adapt_config_dict( is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 ) is_eagle = "eagle" in model.lower() - if is_moe: + if is_eagle and not is_moe: + # Dense EAGLE draft model (e.g. Mistral Small 4 EAGLE). + # Uses MLA attention like MistralLarge3 but has no MoE layers. + # Set model_type to deepseek_v3 for MLA support, and override + # MoE fields so all layers are dense. + config_dict["model_type"] = "deepseek_v3" + config_dict["architectures"] = ["MistralLarge3ForCausalLMEagle"] + num_layers = config_dict.get("num_hidden_layers", 0) + config_dict["n_routed_experts"] = 1 + config_dict["first_k_dense_replace"] = num_layers + config_dict["moe_layer_freq"] = 1 + config_dict["n_shared_experts"] = 0 + config_dict["n_group"] = 1 + config_dict["topk_group"] = 1 + config_dict["num_experts_per_tok"] = 1 + config_dict["moe_intermediate_size"] = 1 + config_dict["routed_scaling_factor"] = 1.0 + config_dict["topk_method"] = None + config_dict["scoring_func"] = "softmax" + config_dict["routing_method_type"] = 1 + elif is_moe: if is_mistral_large_3: config_dict = _remap_moe_args(config_dict) config_dict["model_type"] = "deepseek_v3" @@ -114,13 +134,17 @@ def _remap_mistral_yarn_args(config: dict) -> dict: "original_max_position_embeddings": "original_max_position_embeddings", "beta": "beta_fast", "alpha": "beta_slow", - "apply_scale": None, + "apply_scale": "apply_yarn_scaling", } yarn_config = config.get("yarn") or {} config["rope_scaling"] = { - "rope_type": "yarn", + "rope_type": "deepseek_yarn", "mscale_all_dim": 1, } + # Include rope_theta in rope_scaling if present at the top level, + # as transformers yarn validation requires it. + if "rope_theta" in config: + config["rope_scaling"]["rope_theta"] = config["rope_theta"] for old_name, new_name in yarn_config_map.items(): if old_name in yarn_config: value = yarn_config.pop(old_name)