diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index e79899e6fe7a..14e78d383523 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -456,6 +456,7 @@ th {
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ |
| `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ |
+| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/step-3.5-flash`, etc. | | ✅︎ |
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py
index 8f28e967aec3..66727a3099ee 100644
--- a/tests/kernels/core/test_activation.py
+++ b/tests/kernels/core/test_activation.py
@@ -17,6 +17,8 @@
QuickGELU,
SiluAndMul,
SwigluOAIAndMul,
+ SwigluStepAndMul,
+ swiglustep_and_mul_triton,
)
from vllm.utils.torch_utils import set_random_seed
@@ -36,6 +38,7 @@
"gelu_tanh",
"fatrelu",
"swigluoai_and_mul",
+ "swiglustep_and_mul",
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,9 +78,12 @@ def test_act_and_mul(
elif activation == "swigluoai_and_mul":
layer = SwigluOAIAndMul()
fn = torch.ops._C.swigluoai_and_mul
+ elif activation == "swiglustep_and_mul":
+ layer = SwigluStepAndMul()
+ fn = swiglustep_and_mul_triton
out = layer(x)
ref_out = layer.forward_native(x)
- if activation == "swigluoai_and_mul":
+ if activation in ["swigluoai_and_mul", "swiglustep_and_mul"]:
rtol = {
# For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16: 2e-3,
@@ -104,7 +110,7 @@ def _get_rtol(output) -> float:
opcheck(fn, (out, x, threshold))
elif activation == "swigluoai_and_mul":
opcheck(fn, (out, x, layer.alpha, layer.limit))
- else:
+ elif activation != "swiglustep_and_mul":
opcheck(fn, (out, x))
diff --git a/tests/models/registry.py b/tests/models/registry.py
index c2760d37f4cd..1bee16c81d01 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -484,6 +484,9 @@ def check_available_online(
"Step1ForCausalLM": _HfExamplesInfo(
"stepfun-ai/Step-Audio-EditX", trust_remote_code=True
),
+ "Step3p5ForCausalLM": _HfExamplesInfo(
+ "stepfun-ai/step-3.5-flash", is_available_online=False
+ ),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@@ -1113,6 +1116,12 @@ def check_available_online(
"Qwen3NextMTP": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"
),
+ "Step3p5MTP": _HfExamplesInfo(
+ "stepfun-ai/Step-3.5-Flash",
+ trust_remote_code=True,
+ speculative_model="stepfun-ai/Step-3.5-Flash",
+ is_available_online=False,
+ ),
}
_TRANSFORMERS_BACKEND_MODELS = {
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index f3de1e171af2..966d168b47ab 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -41,6 +41,7 @@
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
+ "step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
@@ -264,6 +265,11 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)
+ if hf_config.model_type == "step3p5":
+ hf_config.model_type = "step3p5_mtp"
+ n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
+ hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})
+
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index c8822aed2e3f..b53a37a31761 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -17,11 +17,63 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict
logger = init_logger(__name__)
+@triton.jit
+def _swiglustep_and_mul_kernel(
+ o_ptr,
+ o_stride,
+ x_ptr,
+ x_stride,
+ limit: tl.constexpr,
+ d: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+) -> None:
+ i = tl.program_id(axis=0).to(tl.int64)
+ j = tl.program_id(axis=1)
+ o_row_ptr = o_ptr + o_stride * i
+ x_row_ptr = x_ptr + x_stride * i
+ offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < d
+
+ gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
+ up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)
+
+ gate_silu = tl.sigmoid(gate) * gate
+ gate_clamped = tl.minimum(gate_silu, limit)
+ up_clamped = tl.minimum(tl.maximum(up, -limit), limit)
+
+ result = gate_clamped * up_clamped
+ result = result.to(x_ptr.dtype.element_ty)
+ tl.store(o_row_ptr + offsets, result, mask=mask)
+
+
+def swiglustep_and_mul_triton(
+ output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
+):
+ b, n = input.shape
+ assert input.ndim == 2
+ assert n % 2 == 0
+ d = n // 2
+
+ def grid(meta):
+ return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
+
+ _swiglustep_and_mul_kernel[grid](
+ output,
+ output.stride(0),
+ input,
+ input.stride(0),
+ limit=limit,
+ d=d,
+ BLOCK_SIZE=1024,
+ )
+
+
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
@@ -304,6 +356,44 @@ def extra_repr(self) -> str:
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
+# --8<-- [start:swiglustep_and_mul]
+@CustomOp.register("swiglustep_and_mul")
+class SwigluStepAndMul(CustomOp):
+ """An activation function for SwiGLU with clamping.
+
+ Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
+ where d = x.shape[-1] // 2.
+
+ Shapes:
+ x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+ return: (num_tokens, d) or (batch_size, seq_len, d)
+ """
+
+ def __init__(self, limit: float = 7.0):
+ super().__init__()
+ if limit is None:
+ raise ValueError("SwigluStepAndMul requires limit to be set.")
+ self.limit = limit
+
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+ """PyTorch-native implementation equivalent to forward()."""
+ gate, up = x.chunk(2, dim=-1)
+ gate = F.silu(gate)
+ gate = gate.clamp(max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ return gate * up
+
+ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+ d = x.shape[-1] // 2
+ output_shape = x.shape[:-1] + (d,)
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+ swiglustep_and_mul_triton(out, x, self.limit)
+ return out
+
+ def extra_repr(self) -> str:
+ return f"limit={repr(self.limit)}"
+
+
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index fafcf6de6140..00d55bfb7e04 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -144,7 +144,7 @@ def _supports_quant_scheme(
@staticmethod
def _supports_activation(activation: str) -> bool:
- return activation in ["silu"]
+ return activation in ["silu", "swiglustep"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 987388692725..120b3c2d1e2c 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1939,7 +1939,7 @@ def _supports_quant_scheme(
@staticmethod
def _supports_activation(activation: str) -> bool:
- return activation in ["silu", "gelu", "swigluoai"]
+ return activation in ["silu", "gelu", "swigluoai", "swiglustep"]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index a4b20505ea32..75873a92abdb 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -358,6 +358,11 @@ def apply_moe_activation(
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
+ elif activation == "swiglustep":
+ from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
+
+ swiglustep_and_mul_triton(output, input)
+
# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index f38914a7ce33..95f6cc06527a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -189,6 +189,7 @@
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
+ "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
@@ -495,6 +496,7 @@
"MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
+ "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py
new file mode 100644
index 000000000000..f0d7b4a75a9d
--- /dev/null
+++ b/vllm/model_executor/models/step3p5.py
@@ -0,0 +1,894 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Jurassic model."""
+
+from collections.abc import Iterable
+from typing import Any
+
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.distributed import (
+ get_dp_group,
+ get_ep_group,
+ get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ get_tp_group,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE,
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backend import AttentionType
+
+from .interfaces import MixtureOfExperts, SupportsPP
+from .utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ WeightsMapper,
+ extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+class FP32ReplicatedLinear(ReplicatedLinear):
+ """
+ Use FP32 for higher precision.
+ """
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+ assert self.params_dtype == torch.float32
+ return super().forward(x.to(torch.float32))
+
+
+class Step3p5MLP(nn.Module):
+ def __init__(
+ self,
+ config: ModelConfig,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: QuantizationConfig | None = None,
+ reduce_results: bool = True,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=f"{prefix}.down_proj",
+ )
+
+ if hidden_act != "silu":
+ raise ValueError(
+ f"Unsupported activation: {hidden_act}. Only silu is supported for now."
+ )
+ self.act_fn = SiluAndMul()
+ self.prefix = prefix
+ self.hidden_size = hidden_size
+ self.limit = None
+ layer_idx = extract_layer_index(prefix)
+ if (
+ config.swiglu_limits_shared
+ and config.swiglu_limits_shared[layer_idx] is not None
+ and config.swiglu_limits_shared[layer_idx] != 0
+ ):
+ self.limit = config.swiglu_limits_shared[layer_idx]
+ self.act_fn = SwigluStepAndMul(limit=self.limit)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_up, _ = self.gate_up_proj(hidden_states)
+ intermediate_act = self.act_fn(gate_up)
+ output, _ = self.down_proj(intermediate_act)
+ return output
+
+
+class Step3p5Attention(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ max_position: int = 4096 * 32,
+ head_dim: int | None = None,
+ rms_norm_eps: float = 1e-06,
+ qkv_bias: bool = False,
+ rope_theta: float | list[float] | None = 10000,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ rope_scaling: dict[str, Any] | None = None,
+ prefix: str = "",
+ attn_type: str = AttentionType.DECODER,
+ # Step3p5 specific args
+ sliding_window: int | None = None,
+ use_head_wise_attn_gate: bool = False,
+ layer_types: list = None,
+ use_rope_layers: list = None,
+ yarn_only_types: list = None,
+ swa_num_attention_heads: int | None = None,
+ partial_rotary_factor: float = 1.0,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.total_num_heads = num_heads
+ tp_size = get_tensor_model_parallel_world_size()
+ self.layer_idx = extract_layer_index(prefix)
+ if layer_types:
+ enable_sliding_window = layer_types[self.layer_idx] == "sliding_attention"
+ else:
+ enable_sliding_window = self.layer_idx % 2 == 0
+ if yarn_only_types and layer_types[self.layer_idx] not in yarn_only_types:
+ rope_scaling = None
+
+ if sliding_window is not None and enable_sliding_window:
+ sliding_window = sliding_window
+ if swa_num_attention_heads is not None:
+ num_heads = swa_num_attention_heads
+ self.total_num_heads = swa_num_attention_heads
+ else:
+ sliding_window = None
+
+ if isinstance(rope_theta, list):
+ rope_theta = rope_theta[self.layer_idx]
+
+ self.rank = get_tensor_model_parallel_rank()
+ self.partial_rotary_factor = partial_rotary_factor
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim or hidden_size // self.total_num_heads
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.qkv_proj = QKVParallelLinear(
+ hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+ self.o_proj = RowParallelLinear(
+ self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ if rope_scaling is not None and not isinstance(rope_scaling, dict):
+ raise ValueError("rope_scaling must be a dict for Step3p5Attention.")
+
+ rope_parameters: dict[str, Any] = (
+ dict(rope_scaling) if rope_scaling is not None else {}
+ )
+ rope_parameters.setdefault("rope_type", "default")
+ rope_parameters["rope_theta"] = self.rope_theta
+ rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+
+ self.rotary_emb = get_rope(
+ head_size=self.head_dim,
+ max_position=max_position,
+ rope_parameters=rope_parameters,
+ )
+
+ self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
+ self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
+ self.use_head_wise_attn_gate = use_head_wise_attn_gate
+ if use_head_wise_attn_gate:
+ self.g_proj = ColumnParallelLinear(
+ hidden_size,
+ self.total_num_heads,
+ bias=False,
+ prefix=f"{prefix}.g_proj",
+ )
+
+ self.use_rope = True
+ if use_rope_layers:
+ self.use_rope = use_rope_layers[self.layer_idx]
+
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ per_layer_sliding_window=sliding_window,
+ attn_type=attn_type,
+ )
+
+ self.max_position_embeddings = max_position
+ assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
+ self.rotary_dim = (
+ self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ # Add qk-norm inline similar to Qwen3 MOE attention
+ q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+ q_by_head = self.q_norm(q_by_head.contiguous())
+ q = q_by_head.view(q.shape)
+
+ k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+ k_by_head = self.k_norm(k_by_head.contiguous())
+ k = k_by_head.view(k.shape)
+ if self.use_rope:
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ if self.use_head_wise_attn_gate:
+ extra_dims, _ = self.g_proj(hidden_states)
+ output = (
+ attn_output.view(*attn_output.shape[:-1], self.num_heads, self.head_dim)
+ * extra_dims.unsqueeze(-1).sigmoid()
+ )
+ attn_output = output.view(*attn_output.shape)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class FusedMoEBlock(nn.Module):
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.layer_idx = extract_layer_index(prefix)
+
+ self.ep_size = get_ep_group().device_group.size()
+ self.ep_rank = get_ep_group().device_group.rank()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ parallel_config = vllm_config.parallel_config
+
+ self.hidden_size = config.hidden_size
+ self.enable_eplb = parallel_config.enable_eplb
+ self.n_routed_experts = config.moe_num_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
+ self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+ self.physical_expert_end = (
+ self.physical_expert_start + self.n_local_physical_experts
+ )
+
+ if self.tp_size > config.moe_num_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than "
+ f"the number of experts {config.moe_num_experts}."
+ )
+
+ self.gate = FP32ReplicatedLinear(
+ config.hidden_size,
+ config.moe_num_experts,
+ bias=False,
+ quant_config=None,
+ params_dtype=torch.float32, # Use FP32 for higher precision.
+ prefix=f"{prefix}.gate",
+ )
+ self.use_moe_router_bias = config.use_moe_router_bias
+ assert self.use_moe_router_bias, "Only support use_moe_router_bias is true."
+ self.routed_scaling_factor = config.moe_router_scaling_factor
+ self.router_bias = nn.Parameter(
+ torch.zeros(config.moe_num_experts, dtype=torch.float32),
+ requires_grad=False,
+ )
+ self.need_fp32_gate = config.need_fp32_gate
+ assert self.need_fp32_gate, (
+ "Router logits must use FP32 precision for numerical stability."
+ )
+
+ activation = "silu"
+ swiglu_limits = config.swiglu_limits or []
+ swiglu_limit = (
+ swiglu_limits[self.layer_idx]
+ if self.layer_idx < len(swiglu_limits)
+ else None
+ )
+ if swiglu_limit not in (None, 0):
+ swiglu_limit = float(swiglu_limit)
+ assert swiglu_limit == 7.0, (
+ "Swiglu limit in fused moe block only suport 7.0 now."
+ )
+ activation = "swiglustep"
+ logger.debug(
+ "step3p5 layer_idx: %s, activation: %s, limit: %s",
+ self.layer_idx,
+ activation,
+ swiglu_limit,
+ )
+
+ self.share_expert = Step3p5MLP(
+ config=config,
+ hidden_size=self.hidden_size,
+ intermediate_size=config.share_expert_dim,
+ hidden_act="silu",
+ reduce_results=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.share_expert",
+ )
+ self.experts = SharedFusedMoE(
+ shared_experts=self.share_expert,
+ gate=self.gate,
+ num_experts=config.moe_num_experts,
+ top_k=config.moe_top_k,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_expert_weight,
+ quant_config=quant_config,
+ activation=activation,
+ prefix=f"{prefix}.experts",
+ scoring_func=getattr(config, "moe_router_activation", "sigmoid"),
+ e_score_correction_bias=self.router_bias,
+ routed_scaling_factor=config.moe_router_scaling_factor,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ if self.experts.is_internal_router:
+ # In this case, the gate/router runs inside the FusedMoE class
+ fused_moe_out = self.experts(
+ hidden_states=hidden_states, router_logits=hidden_states
+ )
+ else:
+ # router_logits: (num_tokens, n_experts)
+ router_logits, _ = self.gate(hidden_states)
+ fused_moe_out = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+
+ shared_output, final_hidden_states = fused_moe_out
+ if self.share_expert is None:
+ assert shared_output is None
+
+ if self.share_expert is not None:
+ assert shared_output is not None
+ final_hidden_states += shared_output
+
+ if self.tp_size > 1:
+ final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states
+ )
+
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Step3p5DecoderLayer(nn.Module):
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ self.hidden_size = config.hidden_size
+ layer_idx = extract_layer_index(prefix)
+ self.layer_idx = layer_idx
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ if cache_config is not None:
+ cache_config.sliding_window = None
+ if config.att_impl_type == "GQA":
+ num_attention_heads = None
+ num_attention_groups = None
+ head_dim = None
+ if (
+ getattr(config, "attention_other_setting", None)
+ and getattr(config, "layer_types", [])
+ and config.layer_types[layer_idx]
+ == config.attention_other_setting["attention_type"]
+ ):
+ num_attention_heads = config.attention_other_setting[
+ "num_attention_heads"
+ ]
+ num_attention_groups = config.attention_other_setting[
+ "num_attention_groups"
+ ]
+ head_dim = config.attention_other_setting["head_dim"]
+ partial_rotary_factors = getattr(config, "partial_rotary_factors", [])
+ self.self_attn = Step3p5Attention(
+ hidden_size=self.hidden_size,
+ num_heads=num_attention_heads
+ if num_attention_heads
+ else config.num_attention_heads,
+ max_position=config.max_position_embeddings,
+ num_kv_heads=num_attention_groups
+ if num_attention_groups
+ else config.num_attention_groups,
+ rope_theta=config.rope_theta,
+ rms_norm_eps=config.rms_norm_eps,
+ qkv_bias=getattr(config, "attention_bias", False),
+ head_dim=head_dim if head_dim else getattr(config, "head_dim", None),
+ cache_config=cache_config,
+ quant_config=quant_config,
+ rope_scaling=getattr(config, "rope_scaling", None),
+ sliding_window=getattr(config, "sliding_window", None),
+ use_head_wise_attn_gate=getattr(
+ config, "use_head_wise_attn_gate", False
+ ),
+ layer_types=getattr(config, "layer_types", []),
+ use_rope_layers=getattr(config, "use_rope_layers", []),
+ yarn_only_types=getattr(config, "yarn_only_types", []),
+ partial_rotary_factor=partial_rotary_factors[layer_idx]
+ if partial_rotary_factors
+ else 1.0,
+ prefix=f"{prefix}.self_attn",
+ )
+ else:
+ raise ValueError(
+ f"Unsupported attention implementation: {config.att_impl_type}"
+ )
+ self.use_moe = False
+ self.tp_group = get_tp_group()
+ self.use_fused_all_reduce = (
+ get_tensor_model_parallel_world_size() > 1
+ and get_dp_group().world_size == 1
+ )
+ if self.use_fused_all_reduce:
+ logger.warning_once("Enable custom fused all reduce...")
+ else:
+ logger.warning_once("Disable custom fused all reduce...")
+
+ moe_layers_enum = getattr(config, "moe_layers_enum", None)
+ if moe_layers_enum is not None:
+ moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
+ else:
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
+ if layer_idx in moe_layers_idx:
+ self.moe = FusedMoEBlock(
+ vllm_config,
+ prefix=f"{prefix}.moe",
+ )
+ self.use_moe = True
+ else:
+ self.mlp = Step3p5MLP(
+ config=config,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act="silu",
+ quant_config=quant_config,
+ reduce_results=True,
+ prefix=f"{prefix}.mlp",
+ )
+ self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.post_attention_layernorm = GemmaRMSNorm(
+ config.hidden_size, config.rms_norm_eps
+ )
+ self.prefix = prefix
+
+ def add_and_maybe_inplace_all_reduce(
+ self, in1: torch.Tensor, in2: torch.Tensor
+ ) -> torch.Tensor:
+ if not self.use_fused_all_reduce:
+ return in1 + in2
+ return self.tp_group.all_reduce(in1 + in2)
+
+ def forward(
+ self, positions: torch.Tensor, hidden_states: torch.Tensor
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+ hidden_states += residual
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+
+ if self.use_moe:
+ ffn_output = self.moe(hidden_states)
+ else:
+ ffn_output = self.mlp(hidden_states)
+ hidden_states = ffn_output + residual
+ return hidden_states
+
+
+@support_torch_compile
+class Step3p5Model(nn.Module):
+ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+
+ self.vllm_config = vllm_config
+ config = vllm_config.model_config.hf_config
+ self.vocab_size = config.vocab_size
+ self.config = config
+
+ self.moe_num_experts = config.moe_num_experts
+
+ if get_pp_group().is_first_rank or (
+ config.tie_word_embeddings and get_pp_group().is_last_rank
+ ):
+ self.embed_tokens = VocabParallelEmbedding(
+ self.vocab_size,
+ config.hidden_size,
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Step3p5DecoderLayer(
+ vllm_config,
+ prefix=prefix,
+ ),
+ prefix=f"{prefix}.layers",
+ )
+ if get_pp_group().is_last_rank:
+ self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states"], config.hidden_size
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embed_input_ids(input_ids)
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states = layer(positions, hidden_states)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {
+ "hidden_states": hidden_states,
+ }
+ )
+
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ config = self.config
+ assert config.num_attention_groups > 1, "Only support GQA"
+ qkv_params_mapping = []
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ expert_params_mapping = [
+ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
+ (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
+ (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
+ ]
+
+ disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
+
+ for name, loaded_weight in weights:
+ if name.startswith("model."):
+ local_name = name[len("model.") :]
+ full_name = name
+ else:
+ local_name = name
+ full_name = f"model.{name}" if name else "model"
+
+ spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
+ if spec_layer is not None:
+ continue # skip spec decode layers for main model
+
+ # Skip any layers beyond the main model's depth (e.g., MTP layers)
+ if full_name.startswith("model.layers."):
+ parts = full_name.split(".")
+ if len(parts) > 2 and parts[2].isdigit():
+ layer_idx = int(parts[2])
+ if layer_idx >= config.num_hidden_layers:
+ continue
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in local_name:
+ continue
+ if any(
+ disable_moe_stacked_param in local_name
+ for disable_moe_stacked_param in disable_moe_stacked_params
+ ):
+ continue
+ replaced_name = local_name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(replaced_name, self):
+ continue
+ if replaced_name not in params_dict:
+ continue
+ param = params_dict[replaced_name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ loaded_params.add(replaced_name)
+ break
+ else:
+ for param_name, weight_name, shard_id in expert_params_mapping:
+ if weight_name not in local_name:
+ continue
+ replaced_name = local_name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(replaced_name, self):
+ continue
+ if (
+ replaced_name.endswith(".bias")
+ or replaced_name.endswith("_bias")
+ ) and replaced_name not in params_dict:
+ continue
+ if replaced_name not in params_dict:
+ continue
+ param = params_dict[replaced_name]
+ weight_loader = param.weight_loader
+ moe_expert_num = self.moe_num_experts
+ assert loaded_weight.shape[0] == moe_expert_num
+ for expert_id in range(moe_expert_num):
+ loaded_weight_expert = loaded_weight[expert_id]
+ weight_loader(
+ param,
+ loaded_weight_expert,
+ replaced_name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
+ loaded_params.add(replaced_name)
+ break
+ else:
+ for (
+ param_name,
+ weight_name,
+ start_idx,
+ end_idx,
+ ) in qkv_params_mapping:
+ if weight_name not in local_name:
+ continue
+ replaced_name = local_name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(replaced_name, self):
+ continue
+ if replaced_name not in params_dict:
+ continue
+ param = params_dict[replaced_name]
+ dim = param.shape[param.output_dim]
+ begin_idx = int(start_idx * dim)
+ end_idx = int(end_idx * dim)
+ param_slice = param.narrow(
+ param.output_dim, begin_idx, end_idx - begin_idx
+ )
+ param_slice.copy_(loaded_weight)
+ loaded_params.add(replaced_name)
+ break
+ else:
+ if is_pp_missing_parameter(local_name, self):
+ continue
+ if "expert_bias" in local_name:
+ logger.warning_once("ignore expert_bias")
+ continue
+ if local_name not in params_dict:
+ continue
+ param = params_dict[local_name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(local_name)
+ return loaded_params
+
+
+class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_substr={".share_expert.": ".moe.share_expert."}
+ )
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ lora_config = vllm_config.lora_config
+ self.config = config
+ self.vllm_config = vllm_config
+
+ self.model = Step3p5Model(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+
+ self.moe_layers: list[FusedMoEBlock] = []
+ for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+ assert isinstance(layer, Step3p5DecoderLayer)
+ if hasattr(layer, "moe") and isinstance(layer.moe, FusedMoEBlock):
+ self.moe_layers.append(layer.moe)
+
+ if get_pp_group().is_last_rank:
+ self.unpadded_vocab_size = config.vocab_size
+ if lora_config:
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE
+ if not lora_config
+ else lora_config.lora_vocab_padding_size,
+ )
+ self.logits_processor = LogitsProcessor(
+ self.unpadded_vocab_size, config.vocab_size
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
+
+ # Set MoE hyperparameters
+ self.expert_weights = []
+ assert len(self.moe_layers) > 0, "No MoE layers found in the model."
+ example_layer = self.moe_layers[0]
+ self.num_moe_layers = len(self.moe_layers)
+ self.num_expert_groups = 1
+ self.num_shared_experts = 0
+ self.num_logical_experts = example_layer.n_logical_experts
+ self.num_physical_experts = example_layer.n_physical_experts
+ self.num_local_physical_experts = example_layer.n_local_physical_experts
+ self.num_routed_experts = example_layer.n_routed_experts
+ self.num_redundant_experts = example_layer.n_redundant_experts
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ):
+ hidden_states = self.model(
+ input_ids, positions, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.model.norm(hidden_states)
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_tokens(input_ids)
+
+ def set_eplb_state(
+ self,
+ expert_load_view: torch.Tensor,
+ logical_to_physical_map: torch.Tensor,
+ logical_replica_count: torch.Tensor,
+ ) -> None:
+ for layer_idx, layer in enumerate(self.moe_layers):
+ experts = layer.experts
+ assert isinstance(experts, FusedMoE)
+ # Register the expert weights.
+ self.expert_weights.append(experts.get_expert_weights())
+ experts.set_eplb_state(
+ moe_layer_idx=layer_idx,
+ expert_load_view=expert_load_view,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ )
+
+ def update_physical_experts_metadata(
+ self,
+ num_physical_experts: int,
+ num_local_physical_experts: int,
+ ) -> None:
+ assert self.num_local_physical_experts == num_local_physical_experts
+ self.num_physical_experts = num_physical_experts
+ self.num_local_physical_experts = num_local_physical_experts
+ self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+ for layer in self.moe_layers:
+ assert isinstance(layer, FusedMoEBlock)
+ layer.n_local_physical_experts = num_local_physical_experts
+ layer.n_physical_experts = num_physical_experts
+ layer.n_redundant_experts = self.num_redundant_experts
+ layer.experts.update_expert_map()
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+
+def get_spec_layer_idx_from_weight_name(
+ config: ModelConfig, weight_name: str
+) -> int | None:
+ if hasattr(config, "num_nextn_predict_layers") and (
+ config.num_nextn_predict_layers > 0
+ ):
+ layer_idx = config.num_hidden_layers
+ for i in range(config.num_nextn_predict_layers):
+ if weight_name.startswith(
+ f"layers.{layer_idx + i}." # Step3p5Model
+ ) or weight_name.startswith(f"model.layers.{layer_idx + i}."): # Step3p5MTP
+ return layer_idx + i
+ return None
diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py
new file mode 100644
index 000000000000..83e43dce5114
--- /dev/null
+++ b/vllm/model_executor/models/step3p5_mtp.py
@@ -0,0 +1,315 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.sequence import IntermediateTensors
+
+from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
+from .utils import maybe_prefix
+
+logger = init_logger(__name__)
+
+
+class SharedHead(nn.Module):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: QuantizationConfig | None = None,
+ ) -> None:
+ super().__init__()
+ self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=quant_config
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.norm(hidden_states)
+
+
+class Step3p5AMultiTokenPredictorLayer(nn.Module):
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str,
+ ) -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
+ self.shared_head = SharedHead(config=config, quant_config=quant_config)
+ self.mtp_block = Step3p5DecoderLayer(
+ vllm_config,
+ prefix=f"{prefix}.mtp_block",
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ inputs_embeds: torch.Tensor | None = None,
+ spec_step_index: int = 0,
+ ) -> torch.Tensor:
+ assert inputs_embeds is not None
+ inputs_embeds = self.enorm(inputs_embeds)
+ previous_hidden_states = self.hnorm(previous_hidden_states)
+
+ hidden_states = self.eh_proj(
+ torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
+ )
+
+ hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states)
+ return hidden_states
+
+
+class Step3p5AMultiTokenPredictor(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ )
+ self.mtp_start_layer_idx = config.num_hidden_layers
+ self.num_mtp_layers = config.num_nextn_predict_layers
+ # to map the exact layer index from weights
+ self.layers = torch.nn.ModuleDict(
+ {
+ str(idx): Step3p5AMultiTokenPredictorLayer(
+ vllm_config,
+ f"{prefix}.layers.{idx}",
+ )
+ for idx in range(
+ self.mtp_start_layer_idx,
+ self.mtp_start_layer_idx + self.num_mtp_layers,
+ )
+ }
+ )
+
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ inputs_embeds: torch.Tensor | None = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ current_step_idx = spec_step_idx % self.num_mtp_layers
+ return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
+ input_ids,
+ positions,
+ previous_hidden_states,
+ inputs_embeds,
+ current_step_idx,
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ current_step_idx = spec_step_idx % self.num_mtp_layers
+ mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
+ logits = self.logits_processor(
+ mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
+ )
+ return logits
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+
+class Step3p5MTP(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.config = vllm_config.model_config.hf_config
+ self.vllm_config = vllm_config
+ self.model = Step3p5AMultiTokenPredictor(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor | None:
+ return self.model.compute_logits(hidden_states, spec_step_idx)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ expert_params_mapping = [
+ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
+ (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
+ (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if "embed_tokens" not in name and spec_layer is None:
+ continue
+ name = self._rewrite_spec_layer_name(spec_layer, name)
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
+ if "experts" in name or "moe" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if (
+ name.endswith(".bias") or name.endswith("_bias")
+ ) and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ for expert_id in range(loaded_weight.shape[0]):
+ loaded_weight_expert = loaded_weight[expert_id]
+ weight_loader(
+ param,
+ loaded_weight_expert,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
+ loaded_params.add(name)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if (
+ name.endswith(".bias")
+ and name not in params_dict
+ or "tok_embeddings" in name
+ ):
+ continue
+
+ if spec_layer is not None and ".transformer." in name:
+ name = name.replace(".transformer.", ".")
+ if "shared_head" in name:
+ name = name.replace("shared_head.output", "shared_head.head")
+ if "embed_tokens" in name:
+ assert (
+ hasattr(self.config, "num_nextn_predict_layers")
+ and self.config.num_nextn_predict_layers > 0
+ )
+ name = "model.embed_tokens.weight"
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ params_need_to_load = set(params_dict.keys())
+ # Some KV cache scales are optional: checkpoints may omit them and vLLM
+ # will fall back to default scales during initialization.
+ optional_params = {
+ name
+ for name, param in params_dict.items()
+ if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
+ and getattr(param, "numel", lambda: 0)() == 1
+ and getattr(param, "requires_grad", False) is False
+ }
+ params_need_to_load -= optional_params
+ if params_need_to_load != loaded_params:
+ missing_params = list(params_need_to_load - loaded_params)
+ param_name_example = missing_params[0]
+ raise RuntimeError(
+ "Some parameters like "
+ f"{param_name_example} are not in the checkpoint and will falsely "
+ "use random initialization"
+ )
+ return loaded_params
+
+ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
+ """
+ Rewrite the weight name to match the format of the original model.
+ Add .mtp_block for modules in transformer layer block for spec layer
+ """
+ spec_layer_weight_names = [
+ "embed_tokens",
+ "enorm",
+ "hnorm",
+ "eh_proj",
+ "shared_head",
+ ]
+ spec_layer_weight = False
+ for weight_name in spec_layer_weight_names:
+ if weight_name in name:
+ spec_layer_weight = True
+ break
+ if not spec_layer_weight:
+ # treat rest weights as weights for transformer layer block
+ name = name.replace(
+ f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
+ )
+ return name
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 7201d3fbdfcd..8be56b56e9ca 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -84,6 +84,10 @@
"step3_reasoning_parser",
"Step3ReasoningParser",
),
+ "step3p5": (
+ "step3p5_reasoning_parser",
+ "Step3p5ReasoningParser",
+ ),
}
diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py
new file mode 100644
index 000000000000..b93f551426fb
--- /dev/null
+++ b/vllm/reasoning/step3p5_reasoning_parser.py
@@ -0,0 +1,153 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+)
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+from vllm.entrypoints.openai.responses.protocol import (
+ ResponsesRequest,
+)
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+from vllm.tokenizers import TokenizerLike
+
+
+class Step3p5ReasoningParser(BaseThinkingReasoningParser):
+ """
+ Reasoning parser for Step3p5 model.
+
+ Step3p5 uses the ... format, but it tends to emit an extra
+ newline immediately before and/or after the token. This parser trims:
+ - the newline right before
+ - the newline right after
+ """
+
+ @property
+ def start_token(self) -> str:
+ return ""
+
+ @property
+ def end_token(self) -> str:
+ return ""
+
+ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
+ super().__init__(tokenizer, *args, **kwargs)
+
+ # Used to hold a trailing "\n" from reasoning content so we can decide
+ # whether it is immediately before .
+ self._pending_reasoning_newline = False
+
+ # Used to delay the reasoning end detection.
+ # This is necessary to remove the newline appears immediately after ,
+ # which may cause the end detection to be delayed by one round.
+ self.end_offset = 1
+
+ def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+ if self.end_token_id in input_ids and self.end_offset > 0:
+ self.end_offset -= 1
+ return False
+ return self.end_offset < 1
+
+ def is_reasoning_end_streaming(
+ self, input_ids: Sequence[int], delta_ids: Sequence[int]
+ ) -> bool:
+ if self.end_token_id in input_ids and self.end_offset > 0:
+ self.end_offset -= 1
+ return False
+ return self.end_offset < 1
+
+ def extract_reasoning(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest | ResponsesRequest,
+ ) -> tuple[str | None, str | None]:
+ reasoning, content = super().extract_reasoning(model_output, request)
+ if reasoning is not None:
+ reasoning = reasoning.removesuffix("\n")
+ if content is not None:
+ content = content.removeprefix("\n")
+ return reasoning or None, content or None
+
+ def extract_reasoning_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> DeltaMessage | None:
+ # Drop the immediate newline that models often emit after .
+ if previous_text.endswith(self.end_token) and delta_text:
+ if delta_text == "\n":
+ return None
+ elif delta_text.startswith("\n"):
+ remaining = delta_text.removeprefix("\n")
+ return DeltaMessage(content=remaining) if remaining else None
+
+ ret = super().extract_reasoning_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ delta_token_ids,
+ )
+
+ if ret is None:
+ return None
+
+ # Compatibility path for models that don't generate the start token:
+ # treat everything before as reasoning and everything after
+ # as content.
+ if (
+ self.start_token_id not in previous_token_ids
+ and self.start_token_id not in delta_token_ids
+ ):
+ if self.end_token_id in delta_token_ids:
+ end_index = delta_text.find(self.end_token)
+ reasoning = delta_text[:end_index]
+ content = delta_text[end_index + len(self.end_token) :]
+ ret = DeltaMessage(reasoning=reasoning, content=content or None)
+ elif self.end_token_id in previous_token_ids:
+ ret = DeltaMessage(content=delta_text)
+ else:
+ ret = DeltaMessage(reasoning=delta_text)
+
+ reasoning_to_output = ret.reasoning
+ content_to_output = ret.content
+
+ # Reasoning: handle the newline immediately before .
+ if reasoning_to_output is not None:
+ if self._pending_reasoning_newline:
+ reasoning_to_output = "\n" + reasoning_to_output
+ self._pending_reasoning_newline = False
+
+ if reasoning_to_output.endswith("\n"):
+ reasoning_to_output = reasoning_to_output.removesuffix("\n")
+ if self.end_token in delta_text:
+ # Trailing "\n" is right before , drop it.
+ self._pending_reasoning_newline = False
+ else:
+ # Hold the trailing "\n" until we know whether follows.
+ self._pending_reasoning_newline = True
+
+ # Content: handle the newline immediately after .
+ if content_to_output is not None:
+ # No need to get into parser again to remove newline after .
+ self.end_offset -= 1
+
+ # If we have content, reasoning must have ended.
+ self._pending_reasoning_newline = False
+
+ if self.end_token in delta_text and content_to_output.startswith("\n"):
+ content_to_output = content_to_output.removeprefix("\n")
+
+ reasoning_to_output = reasoning_to_output or None
+ content_to_output = content_to_output or None
+ if reasoning_to_output is None and content_to_output is None:
+ return None
+
+ return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output)
diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py
index b26638c0959b..c1a39f2afa02 100644
--- a/vllm/tool_parsers/__init__.py
+++ b/vllm/tool_parsers/__init__.py
@@ -134,6 +134,10 @@
"step3_tool_parser",
"Step3ToolParser",
),
+ "step3p5": (
+ "step3p5_tool_parser",
+ "Step3p5ToolParser",
+ ),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
diff --git a/vllm/tool_parsers/step3p5_tool_parser.py b/vllm/tool_parsers/step3p5_tool_parser.py
new file mode 100644
index 000000000000..b7c8699a03db
--- /dev/null
+++ b/vllm/tool_parsers/step3p5_tool_parser.py
@@ -0,0 +1,1511 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import ast
+import json
+from collections.abc import Sequence
+from typing import Any
+from xml.parsers.expat import ParserCreate
+
+import regex as re
+
+from vllm.entrypoints.chat_utils import make_tool_call_id
+from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionToolsParam,
+)
+from vllm.entrypoints.openai.engine.protocol import (
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall,
+ ToolCall,
+)
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import (
+ ToolParser,
+ ToolParserManager,
+)
+
+logger = init_logger(__name__)
+
+
+class StreamingXMLToolCallParser:
+ """
+ Simplified streaming XML tool call parser
+ Supports streaming input, parsing, and output
+ """
+
+ def __init__(self):
+ self.reset_streaming_state()
+
+ # Tool configuration information
+ self.tools: list[ChatCompletionToolsParam] | None = None
+ self.tool_call_start_token: str = ""
+ self.tool_call_end_token: str = ""
+ self.function_start_token: str = " DeltaMessage:
+ """
+ Parse single streaming XML chunk and return Delta response
+ This is the actual streaming interface that receives chunks
+ one by one and maintains internal state
+
+ Args:
+ xml_chunk: Single XML chunk string
+ Returns:
+ DeltaMessage: Contains delta information generated by this chunk,
+ returns empty response if no complete elements
+ """
+ # Record delta count before processing
+ initial_delta_count = len(self.deltas)
+
+ self.streaming_buffer += xml_chunk
+
+ found_elements = self._process_complete_xml_elements()
+
+ if found_elements:
+ # If complete elements found, check if end events were missed
+ # some tags may not have been triggered
+ try:
+ new_deltas = self.deltas[initial_delta_count:]
+ # If this chunk contains
+ # but didn't generate '}', then complete it
+ if (
+ self.current_call_id is not None
+ and self.function_end_token in xml_chunk
+ ):
+ # - Added '}' (non-empty parameter ending)
+ # - Added '{}' (empty parameter function)
+ has_function_close = any(
+ (
+ td.tool_calls
+ and any(
+ (
+ tc.function
+ and tc.id == self.current_call_id
+ and isinstance(tc.function.arguments, str)
+ and (tc.function.arguments in ("}", "{}"))
+ )
+ for tc in td.tool_calls
+ )
+ )
+ for td in new_deltas
+ )
+ if not has_function_close:
+ # Close potentially unclosed element
+ if self.current_param_name:
+ self._end_element("parameter")
+ if self.current_function_name:
+ self._end_element("function")
+ # If this chunk contains
+ # but didn't generate final empty delta, then complete it
+ if (
+ self.current_call_id is not None
+ and self.tool_call_end_token in xml_chunk
+ ):
+ has_toolcall_close = any(
+ (
+ td.tool_calls
+ and any(
+ (
+ tc.type == "function"
+ and tc.function
+ and tc.function.arguments == ""
+ and tc.id == self.current_call_id
+ )
+ for tc in td.tool_calls
+ )
+ )
+ for td in new_deltas
+ )
+ if not has_toolcall_close:
+ # Close potentially unclosed element
+ if self.current_param_name:
+ self._end_element("parameter")
+ if self.current_function_name:
+ self._end_element("function")
+ self._end_element("tool_call")
+ except Exception as e:
+ logger.warning("Error with fallback parsing: %s", e)
+ # Merge newly generated deltas into single response
+ result_delta = self._merge_new_deltas_to_single_response(
+ initial_delta_count
+ )
+ return result_delta
+ else:
+ # No complete elements, check if there's unoutput text content
+ if self.text_content_buffer and self.tool_call_index == 0:
+ # Has text content but no tool_call yet, output text content
+ text_delta = DeltaMessage(content=self.text_content_buffer)
+ self._emit_delta(text_delta)
+ # Clear buffer to avoid duplicate output
+ self.text_content_buffer = ""
+ return text_delta
+
+ # If this chunk contains end tags but wasn't triggered by parser,
+ # manually complete end events
+ # Only execute when still on the same call as when entered,
+ # to prevent accidentally closing new calls
+ # in multi scenarios
+ if self.current_call_id is not None and (
+ self.function_end_token in xml_chunk
+ or self.tool_call_end_token in xml_chunk
+ ):
+ # Close potentially unclosed element
+ if self.current_param_name:
+ self._end_element("parameter")
+ if self.function_end_token in xml_chunk and self.current_function_name:
+ self._end_element("function")
+ if self.tool_call_end_token in xml_chunk:
+ self._end_element("tool_call")
+ # Return the merged delta result generated by this fallback
+ result_delta = self._merge_new_deltas_to_single_response(
+ initial_delta_count
+ )
+ return result_delta
+
+ # No complete elements, return empty response
+ return DeltaMessage(content=None)
+
+ def _escape_xml_special_chars(self, text: str) -> str:
+ """
+ Escape XML special characters
+ Args:
+ text: Original text
+ Returns:
+ Escaped text
+ """
+ xml_escapes = {
+ "&": "&",
+ "<": "<",
+ ">": ">",
+ '"': """,
+ "'": "'",
+ }
+
+ for char, escape in xml_escapes.items():
+ text = text.replace(char, escape)
+
+ return text
+
+ def _process_complete_xml_elements(self) -> bool:
+ """
+ Process complete XML elements in buffer
+
+ Returns:
+ bool: Whether complete elements were found and processed
+ """
+ found_any = False
+
+ while self.last_processed_pos < len(self.streaming_buffer):
+ # Find next complete xml element
+ element, end_pos = self._find_next_complete_element(self.last_processed_pos)
+ if element is None:
+ # No complete element found, wait for more data
+ break
+
+ # Check if this element should be skipped
+ if self._should_skip_element(element):
+ self.last_processed_pos = end_pos
+ continue
+
+ # Found complete XML element, process it
+ try:
+ preprocessed_element = self._preprocess_xml_chunk(element)
+ # Check if this is the first tool_call start
+ if (
+ (
+ preprocessed_element.strip().startswith("")
+ or preprocessed_element.strip().startswith("")
+ and self.tool_call_index > 0
+ and self.current_call_id
+ and self.current_function_name
+ ):
+ # Reset parser state but preserve generated deltas
+ if self.current_param_name:
+ self._end_element("parameter")
+ if self.current_function_open:
+ self._end_element("function")
+ # Output final tool_call tail delta
+ final_delta = DeltaMessage(
+ role=None,
+ content=None,
+ reasoning_content=None,
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments=""),
+ )
+ ],
+ )
+ self._emit_delta(final_delta)
+ # Reset XML parser and current call state
+ self._reset_xml_parser_after_tool_call()
+ # Parse preprocessed element
+ self.parser.Parse(preprocessed_element, False)
+ found_any = True
+
+ except Exception as e:
+ logger.warning("Error when parsing XML elements: %s", e)
+
+ # Update processed position
+ self.last_processed_pos = end_pos
+
+ return found_any
+
+ def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str:
+ """
+ Fallback: fix incomplete )
+ Examples: ,
+ Also handles missing = cases: -> ,
+ ->
+ Only fixes tags that pass validation (parameter exists in tool definition)
+ """
+ # First, handle missing = cases for function tags
+ chunk = self._fix_missing_equals_in_function_tag(chunk)
+
+ for tag_type in ["parameter", "function"]:
+ pattern = f"<{tag_type}="
+ if pattern not in chunk:
+ continue
+
+ start_idx = chunk.find(pattern)
+ after_tag = chunk[start_idx:]
+ gt_pos = after_tag.find(">")
+ lt_pos = after_tag.find("<", len(pattern))
+
+ # Skip if already well-formed
+ if (
+ gt_pos != -1
+ and (lt_pos == -1 or gt_pos < lt_pos)
+ and pattern in after_tag[:gt_pos]
+ ):
+ continue
+
+ # Extract tag name (stop at space, newline, or <)
+ content = chunk[start_idx + len(pattern) :]
+ end_pos = next(
+ (i for i, ch in enumerate(content) if ch in (" ", "\n", "<")),
+ len(content),
+ )
+ tag_name = content[:end_pos]
+
+ if not tag_name:
+ continue
+
+ # Remove duplicate prefix: ", 1
+ )
+
+ return chunk
+
+ def _fix_missing_equals_in_function_tag(self, chunk: str) -> str:
+ """
+ Fix missing = in function tags: or
+ Examples:
+ ->
+ ->
+ Only fixes if function name exists in tool definition
+ """
+ # already correct
+ if " (with space/newline but no =)
+ pattern1 = r""
+ match1 = re.search(pattern1, chunk)
+ if match1:
+ func_name = match1.group(1).strip()
+ # must validate function name exists before fixing
+ if func_name and self._validate_function_name(func_name):
+ original = match1.group(0)
+ fixed = f""
+ chunk = chunk.replace(original, fixed, 1)
+ return chunk
+
+ # Pattern 2: (no space, no =)
+ # only match "
+ match2 = re.search(pattern2, chunk)
+ if match2:
+ func_name = match2.group(1).strip()
+ # must validate function name exists before fixing
+ if func_name and self._validate_function_name(func_name):
+ original = match2.group(0)
+ fixed = f""
+ chunk = chunk.replace(original, fixed, 1)
+ return chunk
+
+ return chunk
+
+ def _validate_function_name(self, func_name: str) -> bool:
+ """Check if function name exists in tool definitions"""
+ if not self.tools:
+ return False
+
+ for tool in self.tools:
+ if (
+ hasattr(tool, "type")
+ and tool.type == "function"
+ and hasattr(tool, "function")
+ and hasattr(tool.function, "name")
+ and tool.function.name == func_name
+ ):
+ return True
+
+ return False
+
+ def _validate_parameter_name(self, param_name: str) -> bool:
+ """Check if parameter exists in current function's tool definition"""
+ if not self.tools or not self.current_function_name:
+ return True
+
+ for tool in self.tools:
+ if (
+ hasattr(tool, "type")
+ and tool.type == "function"
+ and hasattr(tool, "function")
+ and hasattr(tool.function, "name")
+ and tool.function.name == self.current_function_name
+ ):
+ if not hasattr(tool.function, "parameters"):
+ return True
+ params = tool.function.parameters
+ if isinstance(params, dict):
+ properties = params.get("properties", params)
+ return param_name in properties
+ break
+
+ return True
+
+ def _should_skip_element(self, element: str) -> bool:
+ """
+ Determine whether an element should be skipped
+
+ Args:
+ element: Element to evaluate
+
+ Returns:
+ bool: True means should skip, False means should process
+ """
+
+ # If it's a tool_call XML tag, don't skip
+ if (
+ element.startswith(self.tool_call_start_token)
+ or element.startswith(self.function_start_token)
+ or element.startswith(self.parameter_start_token)
+ ):
+ return False
+
+ # If currently not parsing tool calls and not blank,
+ # collect this text instead of skipping
+ # Only process other XML elements after tool_call appears,
+ # otherwise treat as plain text
+ if self.current_call_id is None and element:
+ # Collect text content to buffer
+ self.text_content_buffer += element
+ return True # Still skip, but content has been collected
+
+ # If currently parsing tool calls,
+ # this might be parameter value, don't skip
+ if self.current_call_id is not None:
+ return False
+
+ # Skip blank content
+ return not element
+
+ def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
+ """
+ Find next complete XML element from specified position
+
+ Args:
+ start_pos: Position to start searching
+
+ Returns:
+ (Complete element string, element end position),
+ returns (None, start_pos) if no complete element found
+ """
+ buffer = self.streaming_buffer[start_pos:]
+
+ if not buffer:
+ return None, start_pos
+
+ if buffer.startswith("<"):
+ # Check if this is an incomplete parameter/function tag
+ # e.g., " not in buffer.split("\n")[0]
+ )
+ is_incomplete_func = (
+ buffer.startswith("" not in buffer.split("\n")[0]
+ )
+
+ if is_incomplete_param or is_incomplete_func:
+ # Find the corresponding closing tag
+ tag_type = "parameter" if is_incomplete_param else "function"
+ closing_tag = f"{tag_type}>"
+ closing_pos = buffer.find(closing_tag)
+
+ if closing_pos != -1:
+ # Found closing tag, return complete element including closing tag
+ complete_element = buffer[: closing_pos + len(closing_tag)]
+ return complete_element, start_pos + closing_pos + len(closing_tag)
+
+ # Need to ensure no new < appears,
+ # find the nearest one between < and >
+ tag_end = buffer.find("<", 1)
+ tag_end2 = buffer.find(">", 1)
+ if tag_end != -1 and tag_end2 != -1:
+ # Next nearest is <
+ if tag_end < tag_end2:
+ return buffer[:tag_end], start_pos + tag_end
+ # Next nearest is >, means found XML element
+ else:
+ return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
+ elif tag_end != -1:
+ return buffer[:tag_end], start_pos + tag_end
+ elif tag_end2 != -1:
+ return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
+ else:
+ # If currently not parsing tool calls (entering a tool_call),
+ # check if starts with or
+ if buffer == ""[: len(buffer)]:
+ # Might be start of , wait for more data
+ return None, start_pos
+ elif (
+ buffer.startswith(" or DeltaMessage:
+ """
+ Merge newly generated deltas from this processing
+ into a single DeltaMessage
+
+ Args:
+ initial_count: Delta count before processing
+
+ Returns:
+ Merged DeltaMessage containing all newly generated delta information
+ """
+ if len(self.deltas) <= initial_count:
+ return DeltaMessage(content=None)
+
+ # Get newly generated deltas
+ new_deltas = self.deltas[initial_count:]
+
+ if len(new_deltas) == 1:
+ # Only one new delta, return directly
+ return new_deltas[0]
+
+ # Merge multiple new deltas
+ merged_tool_calls: list[DeltaToolCall] = []
+ merged_content: str = ""
+
+ for delta in new_deltas:
+ if delta.content:
+ merged_content += delta.content
+ if delta.tool_calls:
+ # For tool_calls, we need to intelligently merge arguments
+ for tool_call in delta.tool_calls:
+ # Find if there's already a tool_call with the same call_id
+ existing_call = None
+ for existing in merged_tool_calls:
+ if existing.id == tool_call.id:
+ existing_call = existing
+ break
+
+ if existing_call and existing_call.function:
+ # Merge to existing tool_call
+ if tool_call.function and tool_call.function.name:
+ existing_call.function.name = tool_call.function.name
+ if (
+ tool_call.function
+ and tool_call.function.arguments is not None
+ ):
+ if existing_call.function.arguments is None:
+ existing_call.function.arguments = ""
+
+ # For streaming JSON parameters,
+ # simply concatenate in order
+ new_args = tool_call.function.arguments
+ existing_call.function.arguments += new_args
+ if tool_call.type:
+ existing_call.type = tool_call.type
+ else:
+ # Add new tool_call
+ merged_tool_calls.append(tool_call)
+
+ return DeltaMessage(
+ content=merged_content if merged_content else None,
+ tool_calls=merged_tool_calls,
+ )
+
+ def _preprocess_xml_chunk(self, chunk: str) -> str:
+ """
+ Preprocess XML chunk, handle non-standard formats,
+ and escape special characters
+
+ Args:
+ chunk: Original XML chunk
+
+ Returns:
+ Processed XML chunk
+ """
+
+ # Check if this is a tool_call related element
+ is_tool_call = False
+ if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
+ self.tool_call_end_token
+ ):
+ is_tool_call = True
+ # Check for function tags (including malformed ones without =)
+ # , , ,
+ if (
+ chunk.startswith(self.function_start_token)
+ or chunk.startswith(self.function_end_token)
+ or chunk.startswith("
+ # This handles cases like: format ->
+ processed = re.sub(r"]+)>", r'', chunk)
+ # Handle format ->
+ processed = re.sub(r"]+)>", r'', processed)
+
+ original_chunk = chunk
+ # If in parameter value accumulation mode
+ if self._pre_inside_parameter:
+ # Parameter end: output accumulated raw text
+ # safely then return
+ if processed.startswith(""):
+ body_text = self._pre_param_buffer
+ # Trigger deferred parsing mode
+ # literal_eval+json output in end_element
+ self.defer_current_parameter = True
+ self.deferred_param_raw_value = body_text
+ # Clean up state
+ self._pre_inside_parameter = False
+ self._pre_param_buffer = ""
+ self._pre_current_param_name = None
+ safe_text = self._escape_xml_special_chars(body_text)
+ return f"{safe_text}"
+ else:
+ # If this is the first block of content after entering parameter
+ # evaluate if deferred parsing is needed;
+ # If not needed, exit accumulation mode
+ # and pass through directly
+ if self._pre_param_buffer == "":
+ # Get current parameter type
+ param_type = (
+ self._get_param_type(self._pre_current_param_name)
+ if self._pre_current_param_name
+ else "string"
+ )
+ # Only these types need deferred parsing to
+ # handle Python literals containing single quotes
+ is_object_type = param_type in ["object"]
+ is_complex_type = (
+ param_type in ["array", "arr", "sequence"]
+ or param_type.startswith("dict")
+ or param_type.startswith("list")
+ )
+
+ # Only delay when contains container symbols
+ # and has single quotes and is complex type
+ has_container_hint = (
+ ("[" in original_chunk)
+ or ("{" in original_chunk)
+ or ("(" in original_chunk)
+ )
+
+ # Determine if deferred parsing is needed
+ need_defer = False
+ if is_complex_type:
+ # Complex type, always need deferred parsing
+ need_defer = True
+ elif (
+ is_object_type
+ and has_container_hint
+ and ("'" in original_chunk)
+ ):
+ # Object type with container symbols
+ # and single quotes, need deferred parsing
+ need_defer = True
+
+ if not need_defer:
+ # No need for deferred parsing,
+ # exit parameter mode directly
+ self._pre_inside_parameter = False
+ return self._escape_xml_special_chars(original_chunk)
+ self._pre_param_buffer += original_chunk
+ return ""
+
+ # Parameter start: enable accumulation
+ if processed.startswith("', processed)
+ if m:
+ self._pre_current_param_name = m.group(1)
+ self._pre_inside_parameter = True
+ self._pre_param_buffer = ""
+ return processed
+
+ # If processed doesn't contain special_token, escape processed
+ # This is because XML parsing encounters special characters
+ # and reports errors, so escaping is needed
+ if not is_tool_call:
+ processed = self._escape_xml_special_chars(processed)
+ return processed
+
+ def _emit_delta(self, delta: DeltaMessage):
+ """Emit Delta response (streaming output)"""
+ self.deltas.append(delta)
+
+ def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
+ """Before starting to process new elements,
+ if there are unclosed tags from before,
+ automatically complete their endings to the parser.
+ - If there are unclosed parameters,
+ it's equivalent to feeding ``
+ - When about to start a new function or tool_call,
+ if there are unclosed functions, complete ``.
+ - When about to start a new tool_call,
+ if there are unclosed tool_calls, complete ``.
+ """
+ # First close unclosed parameters
+ if self.current_param_name:
+ self._end_element("parameter")
+
+ # If about to start new function or tool_call,
+ # and there are unclosed functions, close function first
+ if incoming_tag in ("function", "tool_call") and self.current_function_name:
+ self._end_element("function")
+
+ # If about to start new tool_call,
+ # and there are unclosed tool_calls, close tool_call first
+ if incoming_tag == "tool_call" and self.current_call_id:
+ self._end_element("tool_call")
+
+ def _start_element(self, name: str, attrs: dict[str, str]):
+ """Handle XML start element events"""
+
+ if name == "root":
+ return
+
+ if name == "tool_call":
+ # Before opening new tool_call,
+ # automatically complete previous unclosed tags
+ self._auto_close_open_parameter_if_needed("tool_call")
+
+ self.parameters = {}
+ self.current_call_id = make_tool_call_id()
+ self.current_param_is_first = True
+ self.tool_call_index += 1
+ elif name.startswith("function") or (name == "function"):
+ # If missing tool_call, manually complete
+ if not self.current_call_id:
+ self._start_element("tool_call", {})
+ # Before opening new function,
+ # automatically complete previous unclosed tags (parameter/function)
+ self._auto_close_open_parameter_if_needed("function")
+ function_name = self._extract_function_name(name, attrs)
+ self.current_function_name = function_name
+ self.current_function_open = True
+ if function_name:
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(
+ name=function_name, arguments=""
+ ),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+ elif name.startswith("parameter") or (name == "parameter"):
+ # If previous parameter hasn't ended normally,
+ # complete its end first, then start new parameter
+ self._auto_close_open_parameter_if_needed("parameter")
+ param_name = self._extract_parameter_name(name, attrs)
+ self.current_param_name = param_name
+ self.current_param_value = ""
+ self.current_param_value_converted = ""
+ self.start_quote_emitted = False # Reset start quote flag
+
+ # Only output parameter name and colon,
+ # don't output quotes
+ # decide after parameter value type is determined
+ if param_name:
+ if not self.parameters:
+ # First parameter
+ # start JSON, only output parameter name and colon
+ json_start = f'{{"{param_name}": '
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(
+ name=None, arguments=json_start
+ ),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+ self.current_param_is_first = True
+ else:
+ # Subsequent parameters
+ # add comma and parameter name, no quotes
+ json_continue = f', "{param_name}": '
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(
+ name=None, arguments=json_continue
+ ),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+ self.current_param_is_first = False
+
+ def _char_data(self, data: str):
+ """Handle XML character data events"""
+ if data and self.current_param_name:
+ # If preprocessing stage determines deferred parsing is needed,
+ # only cache character data, no streaming output
+ if self.defer_current_parameter:
+ original_data = data
+ if self.should_emit_end_newline:
+ original_data = "\n" + original_data
+ self.should_emit_end_newline = False
+ if original_data.endswith("\n"):
+ self.should_emit_end_newline = True
+ original_data = original_data[:-1]
+ self.current_param_value += original_data
+ return
+
+ param_type = self._get_param_type(self.current_param_name)
+
+ # Check if this is the first time receiving data for this parameter
+ # If this is the first packet of data and starts with \n, remove \n
+ if not self.current_param_value and data.startswith("\n"):
+ data = data[1:]
+
+ # Output start quote for string type (if not already output)
+ if (
+ param_type in ["string", "str", "text", "varchar", "char", "enum"]
+ and not self.start_quote_emitted
+ ):
+ quote_delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments='"'),
+ )
+ ]
+ )
+ self._emit_delta(quote_delta)
+ self.start_quote_emitted = True
+
+ if not data:
+ return
+
+ original_data = data
+ # Delay output of trailing newline
+ if self.should_emit_end_newline:
+ original_data = "\n" + original_data
+ self.should_emit_end_newline = False
+ if original_data.endswith("\n"):
+ self.should_emit_end_newline = True
+ original_data = original_data[:-1]
+ self.current_param_value += original_data
+
+ # convert parameter value by param_type
+ converted_value = self._convert_param_value(
+ self.current_param_value, param_type
+ )
+ output_data = self._convert_for_json_streaming(converted_value, param_type)
+
+ delta_data = output_data[len(self.current_param_value_converted) :]
+ self.current_param_value_converted = output_data
+
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments=delta_data),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+
+ def _end_element(self, name: str):
+ """Handle XML end element events"""
+
+ if name == "root":
+ return
+
+ # If function or tool_call ends and there are still unclosed parameters,
+ # complete parameter end first
+ if (
+ name.startswith("function") or name == "function" or name == "tool_call"
+ ) and self.current_param_name:
+ self._auto_close_open_parameter_if_needed()
+
+ if (
+ name.startswith("parameter") or name == "parameter"
+ ) and self.current_param_name:
+ # End current parameter
+ param_name = self.current_param_name
+ param_value = self.current_param_value
+
+ # If in deferred parsing mode,
+ # perform overall parsing on raw content
+ # accumulated in preprocessing stage and output once
+ if self.defer_current_parameter:
+ raw_text = (
+ self.deferred_param_raw_value
+ if self.deferred_param_raw_value
+ else param_value
+ )
+ parsed_value = None
+ output_arguments = None
+ try:
+ # If previously delayed trailing newline,
+ # add it back before parsing
+ if self.should_emit_end_newline:
+ raw_for_parse = raw_text + "\n"
+ else:
+ raw_for_parse = raw_text
+ parsed_value = ast.literal_eval(raw_for_parse)
+ output_arguments = json.dumps(parsed_value, ensure_ascii=False)
+ except Exception:
+ # Fallback: output as string as-is
+ output_arguments = json.dumps(raw_text, ensure_ascii=False)
+ parsed_value = raw_text
+
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(
+ name=None, arguments=output_arguments
+ ),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+
+ # Clean up and store
+ self.should_emit_end_newline = False
+ self.parameters[param_name] = parsed_value
+ self.current_param_name = None
+ self.current_param_value = ""
+ self.current_param_value_converted = ""
+ self.start_quote_emitted = False
+ self.defer_current_parameter = False
+ self.deferred_param_raw_value = ""
+ return
+
+ param_type = self._get_param_type(param_name)
+
+ # convert complete parameter value by param_type
+ converted_value = self._convert_param_value(param_value, param_type)
+
+ # Decide whether to add end quote based on parameter type
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+ # For empty string parameters, need special handling
+ if not param_value and not self.start_quote_emitted:
+ # No start quote output,
+ # directly output complete empty string
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments='""'),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+ else:
+ # Non-empty parameter value, output end quote
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments='"'),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+
+ self.should_emit_end_newline = False
+ # Store converted value
+ self.parameters[param_name] = converted_value
+ self.current_param_name = None
+ self.current_param_value = ""
+ self.current_param_value_converted = ""
+ self.start_quote_emitted = False
+
+ elif name.startswith("function") or name == "function":
+ # if there are parameters, close JSON object
+ if self.parameters:
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments="}"),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+ # return empty object
+ else:
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments="{}"),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+ self.current_function_open = False
+ self.current_function_name = (
+ None # Clear function name to prevent duplicate closing
+ )
+
+ elif name == "tool_call":
+ # Before ending tool_call,
+ # ensure function is closed to complete missing right brace
+ if self.current_function_open:
+ # If there are still unclosed parameters, close them first
+ if self.current_param_name:
+ self._end_element("parameter")
+ # Close function, ensure output '}' or '{}'
+ self._end_element("function")
+ # Final Delta
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self.current_call_id,
+ type="function",
+ function=DeltaFunctionCall(name=None, arguments=""),
+ )
+ ]
+ )
+ self._emit_delta(delta)
+
+ # Check if there's text content to output (between tool_calls)
+ if self.text_content_buffer.strip():
+ text_delta = DeltaMessage(content=self.text_content_buffer)
+ self._emit_delta(text_delta)
+
+ self._reset_xml_parser_after_tool_call()
+
+ def setup_parser(self):
+ """Set up XML parser event handlers"""
+ self.parser.buffer_text = True
+ self.parser.StartElementHandler = self._start_element
+ self.parser.EndElementHandler = self._end_element
+ self.parser.CharacterDataHandler = self._char_data
+
+ def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
+ """Set tool configuration information"""
+ self.tools = tools
+
+ def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
+ """Extract function name from various formats"""
+ if attrs and "name" in attrs:
+ return attrs["name"]
+
+ if "=" in name:
+ parts = name.split("=", 1)
+ if len(parts) == 2 and parts[0] == "function":
+ return parts[1]
+
+ return None
+
+ def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None:
+ """Extract parameter name from various formats"""
+ if attrs and "name" in attrs:
+ return attrs["name"]
+
+ if "=" in name:
+ parts = name.split("=", 1)
+ if len(parts) == 2 and parts[0] == "parameter":
+ return parts[1]
+
+ return None
+
+ def _get_param_type(self, param_name: str) -> str:
+ """Get parameter type based on tool configuration, defaults to string
+ Args:
+ param_name: Parameter name
+
+ Returns:
+ Parameter type
+ """
+ if not self.tools or not self.current_function_name:
+ return "string"
+
+ for tool in self.tools:
+ if not hasattr(tool, "type") or not (
+ hasattr(tool, "function") and hasattr(tool.function, "name")
+ ):
+ continue
+ if (
+ tool.type == "function"
+ and tool.function.name == self.current_function_name
+ ):
+ if not hasattr(tool.function, "parameters"):
+ return "string"
+ params = tool.function.parameters
+ if isinstance(params, dict) and "properties" in params:
+ properties = params["properties"]
+ if param_name in properties and isinstance(
+ properties[param_name], dict
+ ):
+ return self.repair_param_type(
+ str(properties[param_name].get("type", "string"))
+ )
+ elif isinstance(params, dict) and param_name in params:
+ param_config = params[param_name]
+ if isinstance(param_config, dict):
+ return self.repair_param_type(
+ str(param_config.get("type", "string"))
+ )
+ break
+ return "string"
+
+ def repair_param_type(self, param_type: str) -> str:
+ """Repair unknown parameter types by treating them as string
+ Args:
+ param_type: Parameter type
+
+ Returns:
+ Repaired parameter type
+ """
+ if (
+ param_type in ["string", "str", "text", "varchar", "char", "enum"]
+ or param_type.startswith("int")
+ or param_type.startswith("uint")
+ or param_type.startswith("long")
+ or param_type.startswith("short")
+ or param_type.startswith("unsigned")
+ or param_type.startswith("num")
+ or param_type.startswith("float")
+ or param_type in ["boolean", "bool", "binary"]
+ or (
+ param_type in ["object", "array", "arr", "sequence"]
+ or param_type.startswith("dict")
+ or param_type.startswith("list")
+ )
+ ):
+ return param_type
+ else:
+ return "string"
+
+ def _convert_param_value(self, param_value: str, param_type: str) -> Any:
+ """Convert value based on parameter type
+ Args:
+ param_value: Parameter value
+ param_type: Parameter type
+
+ Returns:
+ Converted value
+ """
+ if param_value.lower() == "null":
+ return None
+
+ param_type = param_type.strip().lower()
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+ return param_value
+ elif (
+ param_type.startswith("int")
+ or param_type.startswith("uint")
+ or param_type.startswith("long")
+ or param_type.startswith("short")
+ or param_type.startswith("unsigned")
+ ):
+ try:
+ return int(param_value)
+ except (ValueError, TypeError):
+ logger.warning(
+ "Parsed value '%s' is not an integer, degenerating to string.",
+ param_value,
+ )
+ return param_value
+ elif param_type.startswith("num") or param_type.startswith("float"):
+ try:
+ float_param_value: float = float(param_value)
+ return (
+ float_param_value
+ if float_param_value - int(float_param_value) != 0
+ else int(float_param_value)
+ )
+ except (ValueError, TypeError):
+ logger.warning(
+ "Parsed value '%s' is not a float, degenerating to string.",
+ param_value,
+ )
+ return param_value
+ elif param_type in ["boolean", "bool", "binary"]:
+ param_value = param_value.lower()
+ return param_value == "true"
+ else:
+ return param_value
+
+ def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str:
+ """Convert converted_value based on
+ whether it's empty and if type is string
+ Args:
+ converted_value: Converted value
+ param_type: Parameter type
+
+ Returns:
+ Converted string for streaming output
+ """
+ # Check if value is empty, but exclude numeric 0
+ if converted_value is None or converted_value == "":
+ return ""
+
+ if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+ # String type, remove double quotes
+ return json.dumps(converted_value, ensure_ascii=False)[1:-1]
+ else:
+ # Non-string type, return complete JSON string
+ if not isinstance(converted_value, str):
+ return json.dumps(converted_value, ensure_ascii=False)
+ else:
+ return converted_value
+
+ def _reset_xml_parser_after_tool_call(self):
+ """
+ Each tool_call is treated as a separate XML document,
+ so we need to reset the parser after each tool_call.
+ """
+
+ # recreate XML parser
+ self.parser = ParserCreate()
+ self.setup_parser()
+
+ # Reset current tool_call state
+ if self.current_call_id:
+ self.last_completed_call_id = self.current_call_id
+ self.current_call_id = None
+ self.current_function_name = None
+ self.current_function_open = False
+ self.parameters = {}
+ self.current_param_name = None
+ self.current_param_value = ""
+ self.current_param_value_converted = ""
+ self.current_param_is_first = False
+ self.should_emit_end_newline = False
+ self.start_quote_emitted = False
+ self.text_content_buffer = ""
+
+ # Reset preprocessing and deferred parsing state
+ self._pre_inside_parameter = False
+ self._pre_param_buffer = ""
+ self._pre_current_param_name = None
+ self.defer_current_parameter = False
+ self.deferred_param_raw_value = ""
+
+
+@ToolParserManager.register_module("step3p5")
+class Step3p5ToolParser(ToolParser):
+ def __init__(self, tokenizer: TokenizerLike):
+ super().__init__(tokenizer)
+ self.parser = StreamingXMLToolCallParser()
+
+ # Add missing attributes for compatibility with serving_chat.py
+ self.prev_tool_call_arr: list[dict] = []
+ self.streamed_args_for_tool: list[str] = []
+
+ logger.info(
+ "vLLM Successfully import tool parser %s !", self.__class__.__name__
+ )
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ self.parser.reset_streaming_state()
+ # Reset tool call tracking arrays for new extraction
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = []
+ if request:
+ self.parser.set_tools(request.tools)
+ result = self.parser.parse_single_streaming_chunks(model_output)
+ if not result.tool_calls:
+ return ExtractedToolCallInformation(
+ tool_calls=[],
+ tools_called=False,
+ content=result.content,
+ )
+ else:
+ tool_calls = []
+ for tool_call in result.tool_calls:
+ if tool_call.function and tool_call.function.name:
+ tool_calls.append(
+ ToolCall(
+ id=tool_call.id,
+ type=tool_call.type,
+ function=FunctionCall(
+ name=tool_call.function.name,
+ arguments=tool_call.function.arguments,
+ ),
+ )
+ )
+
+ # Update tool call tracking arrays for compatibility
+ tool_index = (
+ tool_call.index
+ if tool_call.index is not None
+ else len(self.prev_tool_call_arr) - 1
+ )
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= tool_index:
+ self.prev_tool_call_arr.append({"name": "", "arguments": ""})
+ while len(self.streamed_args_for_tool) <= tool_index:
+ self.streamed_args_for_tool.append("")
+
+ # Update tool call information
+ self.prev_tool_call_arr[tool_index]["name"] = (
+ tool_call.function.name
+ )
+ self.prev_tool_call_arr[tool_index]["arguments"] = (
+ tool_call.function.arguments
+ )
+
+ # Update streamed arguments
+ if tool_call.function.arguments:
+ self.streamed_args_for_tool[tool_index] = (
+ tool_call.function.arguments
+ )
+
+ return ExtractedToolCallInformation(
+ tool_calls=tool_calls,
+ tools_called=len(tool_calls) > 0,
+ content=result.content,
+ )
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> DeltaMessage | None:
+ if not previous_text:
+ self.parser.reset_streaming_state()
+ # Reset tool call tracking arrays for new streaming session
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = []
+ if request:
+ self.parser.set_tools(request.tools)
+
+ # Model sometimes outputs separately causing delta_text to be empty.
+ # If there were tool_calls before and all current tool_calls have ended,
+ # return an empty tool_call for outer streaming output
+ # to correctly output tool_call field
+ if not delta_text and delta_token_ids:
+ open_calls = current_text.count(
+ self.parser.tool_call_start_token
+ ) - current_text.count(self.parser.tool_call_end_token)
+ if (
+ open_calls == 0
+ and self.parser.tool_call_index > 0
+ or not self.parser.tool_call_index
+ and current_text
+ ):
+ return DeltaMessage(content="")
+ return None
+
+ # Parse the delta text and get the result
+ result = self.parser.parse_single_streaming_chunks(delta_text)
+
+ # Update tool call tracking arrays based on incremental parsing results
+ if result and result.tool_calls:
+ for tool_call in result.tool_calls:
+ if tool_call.function:
+ tool_index = (
+ tool_call.index
+ if tool_call.index is not None
+ else len(self.prev_tool_call_arr) - 1
+ )
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= tool_index:
+ self.prev_tool_call_arr.append({"name": "", "arguments": ""})
+ while len(self.streamed_args_for_tool) <= tool_index:
+ self.streamed_args_for_tool.append("")
+
+ # Update tool name if provided
+ if tool_call.function.name:
+ self.prev_tool_call_arr[tool_index]["name"] = (
+ tool_call.function.name
+ )
+
+ # Update arguments incrementally
+ if tool_call.function.arguments is not None:
+ # Concatenate the incremental arguments
+ # to the existing streamed arguments
+ self.prev_tool_call_arr[tool_index]["arguments"] += (
+ tool_call.function.arguments
+ )
+ self.streamed_args_for_tool[tool_index] += (
+ tool_call.function.arguments
+ )
+ return result
+
+ def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
+ """
+ Skip the remaining_call calculation in serving_chat
+ """
+ return False
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 6679c8dd5548..5094e3fbb844 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -97,6 +97,7 @@ def __getitem__(self, key):
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
step3_text="Step3TextConfig",
+ step3p5="Step3p5Config",
qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig",
lfm2_moe="Lfm2MoeConfig",
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 2f8179602b2e..7cd236532154 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -52,6 +52,7 @@
"Step3VLConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3TextConfig": "vllm.transformers_utils.configs.step3_vl",
+ "Step3p5Config": "vllm.transformers_utils.configs.step3p5",
"Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr",
"Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next",
"Tarsier2Config": "vllm.transformers_utils.configs.tarsier2",
@@ -95,6 +96,7 @@
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
+ "Step3p5Config",
"Qwen3ASRConfig",
"Qwen3NextConfig",
"Tarsier2Config",
diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py
new file mode 100644
index 000000000000..435afd938212
--- /dev/null
+++ b/vllm/transformers_utils/configs/step3p5.py
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Step3p5Config(PretrainedConfig):
+ model_type = "step3p5"
+
+ def __init__(
+ self,
+ hidden_size: int = 5120,
+ intermediate_size: int = 13312,
+ num_attention_heads: int = 40,
+ num_attention_groups: int = 8,
+ num_hidden_layers: int = 48,
+ max_seq_len: int = 4096,
+ vocab_size: int = 65536,
+ rms_norm_eps: float = 1e-5,
+ moe_every_n_layer: int = 2,
+ use_moe: bool = False,
+ moe_intermediate_size: int = 10240,
+ moe_num_experts: int = 16,
+ moe_top_k: int = 4,
+ moe_layer_offset: int = 0,
+ rope_theta: float | list[float] | None = 500000,
+ rope_scaling: dict[str, Any] | None = None,
+ head_dim: int | None = None,
+ share_expert_dim: int | None = None,
+ norm_expert_weight: bool = True,
+ bos_token_id: list[int] | int | None = None,
+ eos_token_id: list[int] | int | None = None,
+ moe_router_activation: str = "softmax",
+ moe_router_scaling_factor: float = 1.0,
+ att_impl_type: str = "GQA",
+ use_head_wise_attn_gate: bool = False,
+ use_moe_router_bias: bool = True,
+ need_fp32_gate: bool = True,
+ layer_types: list[str] | None = None,
+ use_rope_layers: list[bool] | None = None,
+ yarn_only_types: list[str] | None = None,
+ attention_other_setting: dict[str, Any] | None = None,
+ num_nextn_predict_layers: int = 0,
+ swiglu_limits: list[float] | None = None,
+ swiglu_limits_shared: list[float] | None = None,
+ max_position_embeddings: int | None = None,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_attention_groups = num_attention_groups
+ self.num_hidden_layers = num_hidden_layers
+ self.max_seq_len = max_seq_len
+ self.vocab_size = vocab_size
+ self.rms_norm_eps = rms_norm_eps
+ self.use_moe = use_moe
+ self.moe_intermediate_size = moe_intermediate_size
+ self.moe_every_n_layer = moe_every_n_layer
+ self.moe_num_experts = moe_num_experts
+ self.num_experts_per_tok = moe_top_k
+ self.moe_top_k = moe_top_k
+ self.moe_layer_offset = moe_layer_offset
+
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.head_dim = head_dim
+ if share_expert_dim is None:
+ self.share_expert_dim = self.moe_intermediate_size * self.moe_top_k
+ else:
+ self.share_expert_dim = share_expert_dim
+ self.norm_expert_weight = norm_expert_weight
+
+ self.max_position_embeddings = max_position_embeddings
+ self.moe_router_activation = moe_router_activation
+ self.moe_router_scaling_factor = moe_router_scaling_factor
+ self.use_moe_router_bias = use_moe_router_bias
+ self.need_fp32_gate = need_fp32_gate
+
+ self.att_impl_type = att_impl_type
+ self.use_head_wise_attn_gate = use_head_wise_attn_gate
+ self.layer_types = layer_types
+ self.use_rope_layers = use_rope_layers
+ self.yarn_only_types = yarn_only_types
+ self.attention_other_setting = attention_other_setting
+ self.num_nextn_predict_layers = num_nextn_predict_layers
+ self.swiglu_limits = swiglu_limits
+ self.swiglu_limits_shared = swiglu_limits_shared
+
+ resolved_bos_token_id = 1 if bos_token_id is None else bos_token_id
+ resolved_eos_token_id = [2, 3] if eos_token_id is None else eos_token_id
+ self.bos_token_id = resolved_bos_token_id
+ self.eos_token_id = resolved_eos_token_id
+
+ super().__init__(
+ bos_token_id=resolved_bos_token_id,
+ eos_token_id=resolved_eos_token_id,
+ **kwargs,
+ )