Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2e20be6
[qwen3_next vl] first push draft model file
Jan 21, 2026
fd9bdb7
[qwen3_next vl] load moe, need to fix some bugs
Jan 22, 2026
f1cec48
[mrope] fix for transformers v5:
Jan 22, 2026
6f2dfe1
[qwen3_next vl] support moe version
Jan 22, 2026
89f10c9
[qwen3_next vl] support dense model
Jan 22, 2026
c6a6eaa
[lint] delete debug info
Jan 23, 2026
ab2f31b
[qwen3_next vl] support mtp model
Jan 23, 2026
c046573
fix: add multimodal config for qwen3.5
Jan 24, 2026
afc5a35
refactor(srt): Renamed modules and simplified the Qwen 3.5 model code…
Jan 24, 2026
39f3499
fix: rope_theta and partial_rotary_factor factor
Jan 24, 2026
6c04e0b
[qwen3_5] add qwen3_5 dense and moe model
Jan 24, 2026
f5eeb96
fix(srt): import FusedMoE
Jan 24, 2026
20ed51f
fix: offline preprocessed video
Jan 29, 2026
ed4bb6d
1. fallback qwen3vl
Feb 4, 2026
f004a2c
fix: remove visual parameter for Qwen3_5ForCausalLM
Feb 4, 2026
7a8906f
fix import
Feb 5, 2026
14f5fa3
[qwen3_next vl] first push draft model file
Jan 21, 2026
69c5b3a
[qwen3_next vl] load moe, need to fix some bugs
Jan 22, 2026
0ad9dc1
[mrope] fix for transformers v5:
Jan 22, 2026
e263cac
[qwen3_next vl] support moe version
Jan 22, 2026
e71b10c
[qwen3_next vl] support dense model
Jan 22, 2026
11ef96c
[lint] delete debug info
Jan 23, 2026
36ddc67
[qwen3_next vl] support mtp model
Jan 23, 2026
f3e43e8
fix: add multimodal config for qwen3.5
Jan 24, 2026
87e60df
refactor(srt): Renamed modules and simplified the Qwen 3.5 model code…
Jan 24, 2026
542670c
fix: rope_theta and partial_rotary_factor factor
Jan 24, 2026
2bbfe83
[qwen3_5] add qwen3_5 dense and moe model
Jan 24, 2026
12922e6
fix(srt): import FusedMoE
Jan 24, 2026
0cc1cba
fix: offline preprocessed video
Jan 29, 2026
e7b8130
1. fallback qwen3vl
Feb 4, 2026
f867390
fix: remove visual parameter for Qwen3_5ForCausalLM
Feb 4, 2026
b7e00fd
fix import
Feb 5, 2026
a4a6b8c
remove Qwen3NextVLConfig
Feb 5, 2026
bba3160
Merge branch 'yuche/qwen_next_vl' of http://gitlab.alibaba-inc.com/Da…
Feb 5, 2026
68b00b6
fix(logits_processor): clean up redundant code and merge conflict mar…
Feb 5, 2026
4ddc7c9
fix: support RadixLinearAttention
Feb 5, 2026
ac32bb8
fix: add target_extend_input_embeds for mtp(vl embedding)
Feb 5, 2026
179f122
refactor(models): optimize convolution weight passing logic in Qwen3_…
Feb 7, 2026
5ede952
fix: add activation type
Feb 7, 2026
5488ab2
fix: reduce_results for out_proj
Feb 7, 2026
6828225
refactor(logits): Replace target_extend_input_embeds with mm_input_em…
Feb 8, 2026
58dfb2b
fix: rename target_extend_input_embeds to mm_input_embeds
Feb 8, 2026
0df689a
fix: indent for deepstack
Feb 9, 2026
4b46a5a
refactor(srt): Streamline redundant comments in Qwen3_5MultiTokenPred…
Feb 9, 2026
43623bd
feat(moe): fix mamba cache support
Feb 9, 2026
f175dc3
fix: mamba cache support Qwen3_5ForConditionalGeneration
Feb 9, 2026
b551a40
pre-commit
Feb 9, 2026
35a201d
fix qwen3.5 video preprocess
Feb 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmark/kernels/fused_moe_triton/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def get_model_config(
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
]:
E = config.num_experts // ep_size
topk = config.num_experts_per_tok
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.olmo3 import Olmo3Config
from sglang.srt.configs.qwen3_5 import Qwen3_5Config, Qwen3_5MoeConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Expand All @@ -43,6 +44,8 @@
"KimiLinearConfig",
"KimiK25Config",
"Qwen3NextConfig",
"Qwen3_5Config",
"Qwen3_5MoeConfig",
"DotsVLMConfig",
"DotsOCRConfig",
"FalconH1Config",
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ def _config_draft_model(self):
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
self.hf_config.num_nextn_predict_layers = 1

if is_draft_model and self.hf_config.architectures[0] in [
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
]:
self.hf_config.architectures[0] = "Qwen3_5ForCausalLMMTP"
self.hf_config.num_nextn_predict_layers = 1

if is_draft_model and self.hf_config.architectures[0] == "ExaoneMoEForCausalLM":
self.hf_config.architectures[0] = "ExaoneMoEForCausalLMMTP"
self.hf_config.num_nextn_predict_layers = 1
Expand Down Expand Up @@ -1193,6 +1200,8 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
"Qwen3OmniMoeForConditionalGeneration",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
Expand Down
113 changes: 113 additions & 0 deletions python/sglang/srt/configs/qwen3_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from transformers import PretrainedConfig

from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.qwen3_vl import Qwen3VLVisionConfig


class Qwen3_5VisionConfig(Qwen3VLVisionConfig):
    # Vision-tower configuration for Qwen3.5. Reuses all fields from the
    # Qwen3-VL vision config unchanged; only the model type and the key under
    # which this sub-config is stored in the composite config are re-registered.
    model_type = "qwen3_5"
    base_config_key = "vision_config"


class Qwen3_5TextConfig(Qwen3NextConfig):
    """Text-backbone configuration for Qwen3.5.

    A thin wrapper over `Qwen3NextConfig` that re-registers the model type and
    sub-config key, and guarantees `rope_scaling` is always a dict.
    """

    model_type = "qwen3_5_text"
    base_config_key = "text_config"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Normalize an absent rope_scaling to an empty dict so downstream
        # rope helpers can probe it with membership tests instead of
        # special-casing None.
        if self.rope_scaling is None:
            self.rope_scaling = {}


class Qwen3_5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3_5Model`]. It is used to instantiate a
    Qwen3.5 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen3.5.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5TextConfig`):
            The config object or dictionary of the text backbone.
        vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3_5VisionConfig`):
            The config object or dictionary of the vision backbone.
        image_token_id (`int`, *optional*, defaults to 151655):
            The image token index to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 151656):
            The video token index to encode the video prompt.
        vision_start_token_id (`int`, *optional*, defaults to 151652):
            The start token index to encode the vision prompt.
        vision_end_token_id (`int`, *optional*, defaults to 151653):
            The end token index to encode the vision prompt.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie the word embeddings.

    ```python
    >>> from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5Config

    >>> # Initializing a Qwen3.5 style configuration
    >>> configuration = Qwen3_5Config()

    >>> # Initializing a model from the Qwen3.5 style configuration
    >>> model = Qwen3_5ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_5"
    sub_configs = {
        "vision_config": Qwen3_5VisionConfig,
        "text_config": Qwen3_5TextConfig,
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        vision_start_token_id=151652,
        vision_end_token_id=151653,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Accept a dict (deserialized JSON), None (use defaults), or an
        # already-constructed config object. Without the final branch, passing
        # a config instance would leave `self.vision_config` unset and crash
        # on first attribute access.
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            self.vision_config = vision_config

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()
        else:
            self.text_config = text_config

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)


class Qwen3_5MoeVisionConfig(Qwen3_5VisionConfig):
    # MoE variant shares the dense vision config verbatim; only the registered
    # model type differs.
    model_type = "qwen3_5_moe"


class Qwen3_5MoeTextConfig(Qwen3_5TextConfig):
    # MoE variant shares the dense text config verbatim; only the registered
    # model type differs.
    model_type = "qwen3_5_moe_text"


class Qwen3_5MoeConfig(Qwen3_5Config):
    # Composite config for the MoE variant: identical structure and __init__
    # behavior to Qwen3_5Config, but wires in the MoE-specific vision/text
    # sub-config classes so `sub_configs` resolves to the right types.
    model_type = "qwen3_5_moe"
    sub_configs = {
        "vision_config": Qwen3_5MoeVisionConfig,
        "text_config": Qwen3_5MoeTextConfig,
    }
11 changes: 11 additions & 0 deletions python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class LogitsProcessorOutput:
## Part 5: Customized Info
customized_info: Optional[Dict[str, List[Any]]] = None

mm_input_embeds: Optional[torch.Tensor] = None


@dataclasses.dataclass
class LogitsMetadata:
Expand Down Expand Up @@ -146,6 +148,8 @@ class LogitsMetadata:
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False

mm_input_embeds: Optional[torch.Tensor] = None

@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if (
Expand Down Expand Up @@ -196,6 +200,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DpPaddingMode.SUM_LEN,
mm_input_embeds=forward_batch.mm_input_embeds,
)

def compute_dp_attention_metadata(self):
Expand Down Expand Up @@ -341,6 +346,7 @@ def forward(
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
mm_input_embeds=logits_metadata.mm_input_embeds,
)

# Start to process input logprobs
Expand Down Expand Up @@ -386,6 +392,7 @@ def forward(
input_top_logprobs_idx=logprobs_result.input_top_logprobs_idx,
input_token_ids_logprobs_val=logprobs_result.input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=logprobs_result.input_token_ids_logprobs_idx,
mm_input_embeds=logits_metadata.mm_input_embeds,
)

def _get_pruned_states(
Expand Down Expand Up @@ -1067,6 +1074,10 @@ def compute_logprobs_for_multi_item_scoring(
input_top_logprobs_idx=input_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
# FIXME: These fields are not logits-related but are passed through here as a
# workaround since ForwardBatch is local to forward_batch_generation().
# They should be moved to GenerationBatchResult to keep this class clean.
mm_input_embeds=logits_metadata.mm_input_embeds,
)


Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,7 +1822,9 @@ def get_rope_index(
**kwargs,
)
if (
model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe")
model_type.startswith("qwen3_vl")
or model_type.startswith("qwen3_vl_moe")
or model_type.startswith("qwen3_5")
) and video_grid_thw is not None:
video_grid_thw = torch.repeat_interleave(
video_grid_thw, video_grid_thw[:, 0], dim=0
Expand Down Expand Up @@ -1922,6 +1924,8 @@ def get_rope_index(
"qwen2_vl",
"qwen3_vl",
"qwen3_vl_moe",
"qwen3_5",
"qwen3_5_moe",
):
t_index = (
torch.arange(llm_grid_t, device=position_ids.device)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,7 @@ def general_mm_embed_routine(
if isinstance(feature, torch.Tensor) and feature.is_cuda:
mm_item.feature = feature.to("cpu", non_blocking=True)
forward_batch.mm_inputs = None
forward_batch.mm_input_embeds = input_embeds
else:
input_embeds = embed_tokens(input_ids)
# Copy to pre-allocated buffer if available (for CUDA graph address stability)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
# Speculative decoding
spec_info: Optional[SpecInput] = None
spec_algorithm: SpeculativeAlgorithm = None
mm_input_embeds: Optional[torch.Tensor] = None
capture_hidden_mode: CaptureHiddenMode = None

# For padding
Expand Down
17 changes: 14 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
Lfm2Config,
NemotronH_Nano_VL_V2_Config,
NemotronHConfig,
Qwen3_5Config,
Qwen3_5MoeConfig,
Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig
Expand Down Expand Up @@ -1498,8 +1500,15 @@ def qwen3_next_config(self):

@property
def hybrid_gdn_config(self):
config = self.model_config.hf_config
if isinstance(config, Qwen3NextConfig | JetNemotronConfig | JetVLMConfig):
config = self.model_config.hf_config.get_text_config()
if isinstance(
config,
Qwen3NextConfig
| Qwen3_5Config
| Qwen3_5MoeConfig
| JetNemotronConfig
| JetVLMConfig,
):
return config
return None

Expand Down Expand Up @@ -2476,7 +2485,9 @@ def compute_logprobs_only(
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
rope_scaling = getattr(
self.model_config.hf_text_config, "rope_parameters", None
) or getattr(self.model_config.hf_text_config, "rope_scaling", {})
if rope_scaling is None:
return False
is_mrope_enabled = "mrope_section" in rope_scaling
Expand Down
Loading
Loading