Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7eeffcb
Add Mistral4/Pixtral support changes
JustinTong0323 Feb 28, 2026
c3297fc
lint
JustinTong0323 Mar 5, 2026
296fcd5
Add special handling for mistral 4
JustinTong0323 Mar 5, 2026
c7457c9
add reasoning parser for mistral
JustinTong0323 Mar 5, 2026
557d6fd
Set default reasoning_effort to None in ChatCompletionRequest
JustinTong0323 Mar 5, 2026
2c4349f
fix: Add activation type mapping for FlashInfer in moe_runner
JustinTong0323 Mar 5, 2026
0322d01
fix: add reasoning request handling for mistral 4
JustinTong0323 Mar 6, 2026
4802ecc
fix: streamline vision config handling in get_processor function
JustinTong0323 Mar 6, 2026
04a8673
fix: adjust patch grid size calculation to incorporate spatial merge …
JustinTong0323 Mar 6, 2026
3ae9d1e
fix(pixtral): use effective_patch_size for image resize and simplify …
JustinTong0323 Mar 12, 2026
07744ec
feat(mmmu): add --model and --reasoning-effort flags to benchmark
JustinTong0323 Mar 12, 2026
0f1471e
cleanup: remove redundant activation_type mapping and unused ncols va…
JustinTong0323 Mar 12, 2026
e10aa5a
fix reasoning trace having answer and benchmark getting no answers ev…
Mar 16, 2026
2041c65
possible fix for -HF chkpt
Mar 16, 2026
01c72d6
LeanStral works
Mar 16, 2026
0da0ae3
Merge branch 'main' into mistral4-support
JustinTong0323 Mar 16, 2026
afe8772
lint
JustinTong0323 Mar 16, 2026
d508481
fix: update model name in MistralDetector docstring (2602 -> 2603)
JustinTong0323 Mar 16, 2026
34a699f
fix: expose mistral load format and update MistralDetector docstring
JustinTong0323 Mar 16, 2026
7da7666
fix: use correct custom op name for trtllm_fp8_per_tensor_scale_moe_w…
JustinTong0323 Mar 16, 2026
c04df33
feat: auto-detect Mistral native format and set load_format='mistral'
JustinTong0323 Mar 16, 2026
e22540b
lint
JustinTong0323 Mar 16, 2026
bbc7267
fix: add defaults to PretrainedConfig subclass annotations for transf…
JustinTong0323 Mar 16, 2026
638f439
fix: only pass reasoning_effort to chat template when explicitly set
JustinTong0323 Mar 16, 2026
77675d1
fix: support multiple consecutive compact tool calls in Mistral detector
JustinTong0323 Mar 16, 2026
a644402
fix: workaround Mistral tokenizer marking [THINK]/[/THINK] as special…
JustinTong0323 Mar 17, 2026
943abd5
Merge branch 'main' into mistral4-support
JustinTong0323 Mar 17, 2026
773f851
fix: support dense EAGLE speculative decoding for Mistral Small 4
JustinTong0323 Mar 17, 2026
2daf4cc
fix: respect apply_scale=false in Mistral yarn RoPE config
dbari Mar 17, 2026
5d7e3d5
Merge branch 'main' into mistral4-support
JustinTong0323 Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions benchmark/mmmu/bench_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

import argparse
import asyncio
import base64
import mimetypes
import re
import sys
import time
import traceback
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional, Tuple

import aiohttp
Expand Down Expand Up @@ -74,52 +77,76 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:


async def process_sample(
client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
client: Any,
sample: dict,
sampling_params: dict,
model: str,
reasoning_effort: Optional[str] = None,
lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
"""Send a single sample to the LLM and return (sample, response)."""
prompt = sample["final_input_prompt"]
prefix, suffix = _get_prefix_suffix(prompt)
image = sample["image"]
assert image is not None
image_path = sample["image_path"]
extra_body = None if lora_path is None else {"lora_path": lora_path}
if image_path and not image_path.startswith(("http://", "https://", "data:")):
p = Path(image_path)
mime = mimetypes.guess_type(str(p))[0] or "image/png"
with open(p, "rb") as f:
b64 = base64.b64encode(f.read()).decode()
image_url = f"data:{mime};base64,{b64}"
else:
image_url = image_path
extra_body = {"lora_path": lora_path} if lora_path else None
payload = {
"model": "default",
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prefix},
{"type": "image_url", "image_url": {"url": image_path}},
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": suffix},
],
}
],
"extra_body": extra_body,
**sampling_params,
}
if sampling_params:
payload.update(sampling_params)
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
response = await client.chat.completions.create(**payload)
return sample, response.choices[0].message.content
msg = response.choices[0].message
content = msg.content
if content is None:
content = getattr(msg, "reasoning_content", None)
return sample, content


async def process_sample_with_semaphore(
semaphore: asyncio.Semaphore,
client: Any,
sample: dict,
sampling_params: dict,
model: str,
reasoning_effort: Optional[str] = None,
lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
"""Wrap process_sample with a semaphore for concurrency control."""
async with semaphore:
return await process_sample(client, sample, sampling_params, lora_path)
return await process_sample(
client, sample, sampling_params, model, reasoning_effort, lora_path
)


async def eval_mmmu(args) -> None:
"""Main evaluation loop with concurrency control."""
eval_args = EvalArgs.from_cli_args(args)
sampling_params = get_sampling_params(eval_args)
samples = prepare_samples(eval_args)
model = args.model
reasoning_effort = eval_args.reasoning_effort
lora_path = eval_args.lora_path
answer_dict = {}
out_samples = {}
Expand All @@ -146,7 +173,7 @@ async def eval_mmmu(args) -> None:
# this is mainly for profiling
for sample in tqdm(samples):
_, response = await process_sample(
client, sample, sampling_params, lora_path
client, sample, sampling_params, model, reasoning_effort, lora_path
)
sample["original_response"] = response
answer = (
Expand All @@ -164,7 +191,13 @@ async def eval_mmmu(args) -> None:
semaphore = asyncio.Semaphore(args.concurrency)
tasks = [
process_sample_with_semaphore(
semaphore, client, sample, sampling_params, lora_path
semaphore,
client,
sample,
sampling_params,
model,
reasoning_effort,
lora_path,
)
for sample in samples
]
Expand Down Expand Up @@ -202,6 +235,12 @@ async def eval_mmmu(args) -> None:

def parse_args():
    """Build the benchmark's command-line parser and return the parsed args.

    Registers ``--model`` (the model name placed in API requests, defaulting
    to ``"default"``), lets ``EvalArgs.add_cli_args`` register the evaluation
    options on the same parser, and then hands the parser to
    ``add_common_sglang_args_and_parse``.

    Returns:
        Whatever ``add_common_sglang_args_and_parse`` returns for the parser
        (presumably the parsed namespace -- confirm against that helper).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="default",
        help="Model name to use in API requests.",
    )
    EvalArgs.add_cli_args(parser)
    return add_common_sglang_args_and_parse(parser)
Expand Down
8 changes: 8 additions & 0 deletions benchmark/mmmu/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class EvalArgs:
temperature: Optional[float] = None
response_answer_regex: str = "(.*)"
lora_path: Optional[str] = None
reasoning_effort: Optional[str] = None

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -120,6 +121,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=EvalArgs.lora_path,
help="Specify the LoRA path to use for evaluation. If specified, the value will be specified in the body of every request as `lora-path`.",
)
parser.add_argument(
"--reasoning-effort",
type=str,
default=EvalArgs.reasoning_effort,
choices=["none", "high"],
Comment thread
JustinTong0323 marked this conversation as resolved.
help="Reasoning effort for the model (none or high).",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/configs/deepseek_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,8 @@ def __init__(
class DeepseekVLV2Config(PretrainedConfig):
# model_type = "deepseek_vl_v2"
model_type = "deepseek-ocr"
vision_config: VisionEncoderConfig
projector_config: MlpProjectorConfig
vision_config: VisionEncoderConfig = None
projector_config: MlpProjectorConfig = None

tile_tag: str = "2D"
global_view_pos: str = "head"
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/configs/deepseekvl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,9 @@ def __init__(

class DeepseekVL2Config(PretrainedConfig):
model_type = "deepseek_vl_v2"
vision_config: DeepseekVL2VisionEncoderConfig
projector_config: DeepseekVL2MlpProjectorConfig
language_config: DeepseekV2Config
vision_config: DeepseekVL2VisionEncoderConfig = None
projector_config: DeepseekVL2MlpProjectorConfig = None
language_config: DeepseekV2Config = None

tile_tag: str = "2D"
global_view_pos: str = "head"
Expand Down
24 changes: 12 additions & 12 deletions python/sglang/srt/configs/janus_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ class SigLIPVisionCfg:

class MultiModalityConfig(PretrainedConfig):
model_type = "multi_modality"
vision_config: VisionConfig
aligner_config: AlignerConfig
vision_config: VisionConfig = None
aligner_config: AlignerConfig = None

gen_vision_config: GenVisionConfig
gen_aligner_config: GenAlignerConfig
gen_head_config: GenHeadConfig
gen_vision_config: GenVisionConfig = None
gen_aligner_config: GenAlignerConfig = None
gen_head_config: GenHeadConfig = None

language_config: LlamaConfig
language_config: LlamaConfig = None

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -595,12 +595,12 @@ def batchify(

class VLMImageProcessorConfig(PretrainedConfig):
model_type = "deepseek_vlm"
image_size: int
min_size: int
image_mean: Union[Tuple[float, float, float], List[float]]
image_std: Union[Tuple[float, float, float], List[float]]
rescale_factor: float
do_normalize: bool
image_size: int = None
min_size: int = None
image_mean: Union[Tuple[float, float, float], List[float]] = None
image_std: Union[Tuple[float, float, float], List[float]] = None
rescale_factor: float = None
do_normalize: bool = None

def __init__(
self,
Expand Down
24 changes: 12 additions & 12 deletions python/sglang/srt/configs/jet_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ class JetBlockConfig:
class JetNemotronConfig(PretrainedConfig):
model_type: str = "jet_nemotron"

efficient_attention_config: dict[str, dict[str, Any]]
hidden_act: str
hidden_size: int
initializer_range: float
intermediate_size: int
layer_types: list[str]
max_position_embeddings: int
num_attention_heads: int
num_key_value_heads: int
rms_norm_eps: float
rope_scaling: None
rope_theta: float
efficient_attention_config: dict[str, dict[str, Any]] = None
hidden_act: str = None
hidden_size: int = None
initializer_range: float = None
intermediate_size: int = None
layer_types: list[str] = None
max_position_embeddings: int = None
num_attention_heads: int = None
num_key_value_heads: int = None
rms_norm_eps: float = None
rope_scaling: None = None
rope_theta: float = None

@property
def full_attention_layer_ids(self) -> list[int]:
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ class ChatCompletionRequest(BaseModel):
return_routed_experts: bool = False
return_cached_tokens_details: bool = False
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = Field(
default="medium",
default=None,
description="Constrains effort on reasoning for reasoning models. "
"'none' disables reasoning entirely, 'low' is the least effort, 'high' is the most effort. "
"Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning "
Expand Down
45 changes: 32 additions & 13 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ def _process_messages(
if self.is_gpt_oss:
request.skip_special_tokens = False

self._patch_mistral_skip_special_tokens(request)

tool_call_constraint = None

# Apply chat template and its stop strings
Expand Down Expand Up @@ -469,19 +471,20 @@ def _apply_jinja_template(
self._handle_last_assistant_message(openai_compatible_messages, request)
)

extra_template_kwargs = {}
if request.reasoning_effort is not None:
extra_template_kwargs["reasoning_effort"] = request.reasoning_effort
if request.chat_template_kwargs:
extra_template_kwargs.update(request.chat_template_kwargs)

try:
prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
reasoning_effort=request.reasoning_effort,
**(
request.chat_template_kwargs
if request.chat_template_kwargs
else {}
),
return_dict=False,
**extra_template_kwargs,
)
except Exception as e:
# If the first attempt fails, try with flat function-only format.
Expand All @@ -497,13 +500,8 @@ def _apply_jinja_template(
tokenize=True,
add_generation_prompt=True,
tools=tools,
reasoning_effort=request.reasoning_effort,
**(
request.chat_template_kwargs
if request.chat_template_kwargs
else {}
),
return_dict=False,
**extra_template_kwargs,
)
except jinja2.TemplateError as template_error:
# Template errors (e.g., from raise_exception in Jinja templates)
Expand Down Expand Up @@ -1234,8 +1232,22 @@ def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx

def _patch_mistral_skip_special_tokens(
self, request: ChatCompletionRequest
) -> None:
"""Mistral uses special tokens ([THINK]/[/THINK]) for reasoning markers,
which get stripped when skip_special_tokens=True."""
if (
self.reasoning_parser in ["mistral"]
and request.reasoning_effort is not None
and request.reasoning_effort != "none"
):
request.skip_special_tokens = False

def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool:
"""Judge whether the request needs reasoning"""
"""Judge whether the request needs reasoning for hybrid reasoning models
NOTE: This is predefined based on model's chat template
"""
if not self.reasoning_parser:
return False
if self.reasoning_parser in ["deepseek-v3"]:
Expand All @@ -1256,6 +1268,13 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool:
not request.chat_template_kwargs
or request.chat_template_kwargs.get("enable_thinking") is not False
)
if self.reasoning_parser in ["mistral"]:
# Mistral models only reason when reasoning_effort is explicitly
# set to a value other than None/"none" (typically "high").
return (
request.reasoning_effort is not None
and request.reasoning_effort != "none"
)
return True # default

async def _process_tool_call_stream(
Expand Down
26 changes: 17 additions & 9 deletions python/sglang/srt/function_call/mistral_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,27 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult
return StreamingParseResult(normal_text=combined_normal, calls=calls)

# Compact: `[TOOL_CALLS]tool_name[ARGS]{...}`
parsed = self._try_parse_compact_args_format(tool_part)
if not parsed:
# Loop to extract all consecutive compact tool calls.
all_calls: list = []
remaining = tool_part
while remaining:
parsed = self._try_parse_compact_args_format(remaining)
if not parsed:
break
func_name, args_obj, consumed = parsed
new_calls = self.parse_base_json(
{"name": func_name, "arguments": args_obj}, tools
)
all_calls.extend(new_calls)
remaining = remaining[consumed:].strip()

if not all_calls:
return StreamingParseResult(normal_text=normal_text, calls=[])
func_name, args_obj, consumed = parsed

calls = self.parse_base_json({"name": func_name, "arguments": args_obj}, tools)
trailing_text = tool_part[consumed:].strip()
combined_normal = (
(normal_text + " " + trailing_text).strip()
if trailing_text
else normal_text
(normal_text + " " + remaining).strip() if remaining else normal_text
)
return StreamingParseResult(normal_text=combined_normal, calls=calls)
return StreamingParseResult(normal_text=combined_normal, calls=all_calls)

def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
# Move kernel call outside context manager to avoid graph breaks
# during torch.compile for piecewise cuda graph.
# Use custom op wrapper for torch.compile compatibility.
output = torch.ops.sglang.trtllm_fp8_per_tensor_scale_moe(
output = torch.ops.sglang.trtllm_fp8_per_tensor_scale_moe_wrapper(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=routing_bias_cast,
hidden_states=a_q,
Expand Down
Loading
Loading