diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index e79899e6fe7a..14e78d383523 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -456,6 +456,7 @@ th { | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | | `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ | +| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/step-3.5-flash`, etc. | | ✅︎ | | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index 8f28e967aec3..66727a3099ee 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -17,6 +17,8 @@ QuickGELU, SiluAndMul, SwigluOAIAndMul, + SwigluStepAndMul, + swiglustep_and_mul_triton, ) from vllm.utils.torch_utils import set_random_seed @@ -36,6 +38,7 @@ "gelu_tanh", "fatrelu", "swigluoai_and_mul", + "swiglustep_and_mul", ], ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -75,9 +78,12 @@ def test_act_and_mul( elif activation == "swigluoai_and_mul": layer = SwigluOAIAndMul() fn = torch.ops._C.swigluoai_and_mul + elif activation == "swiglustep_and_mul": + layer = SwigluStepAndMul() + fn = swiglustep_and_mul_triton out = layer(x) ref_out = layer.forward_native(x) - if activation == "swigluoai_and_mul": + if activation in ["swigluoai_and_mul", "swiglustep_and_mul"]: rtol = { # For fp16, change the relative tolerance from 1e-3 to 2e-3 
torch.float16: 2e-3, @@ -104,7 +110,7 @@ def _get_rtol(output) -> float: opcheck(fn, (out, x, threshold)) elif activation == "swigluoai_and_mul": opcheck(fn, (out, x, layer.alpha, layer.limit)) - else: + elif activation != "swiglustep_and_mul": opcheck(fn, (out, x)) diff --git a/tests/models/registry.py b/tests/models/registry.py index c2760d37f4cd..1bee16c81d01 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -484,6 +484,9 @@ def check_available_online( "Step1ForCausalLM": _HfExamplesInfo( "stepfun-ai/Step-Audio-EditX", trust_remote_code=True ), + "Step3p5ForCausalLM": _HfExamplesInfo( + "stepfun-ai/step-3.5-flash", is_available_online=False + ), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), @@ -1113,6 +1116,12 @@ def check_available_online( "Qwen3NextMTP": _HfExamplesInfo( "Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3" ), + "Step3p5MTP": _HfExamplesInfo( + "stepfun-ai/Step-3.5-Flash", + trust_remote_code=True, + speculative_model="stepfun-ai/Step-3.5-Flash", + is_available_online=False, + ), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f3de1e171af2..966d168b47ab 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -41,6 +41,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", + "step3p5_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] SpeculativeMethod = Literal[ @@ -264,6 +265,11 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) + if hf_config.model_type == "step3p5": + hf_config.model_type = "step3p5_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) + hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]}) 
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    # Clamped SwiGLU: out[i, :d] = min(silu(x[i, :d]), limit)
    #                              * clamp(x[i, d:2d], -limit, limit)
    #
    # Program axis 0 selects the row, axis 1 selects a BLOCK_SIZE-wide column
    # tile. Columns are addressed as `row_ptr + offsets`, i.e. the kernel
    # assumes a unit stride along the last dimension of both tensors.
    i = tl.program_id(axis=0).to(tl.int64)  # int64 so `stride * i` cannot wrap
    j = tl.program_id(axis=1)
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < d  # the last tile of a row may be partial

    # The gate half is the first `d` columns, the up half the following `d`.
    # Both are promoted to fp32 for the arithmetic.
    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    gate_silu = tl.sigmoid(gate) * gate
    gate_clamped = tl.minimum(gate_silu, limit)  # upper bound only
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)  # symmetric clamp

    result = gate_clamped * up_clamped
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Write ``min(silu(gate), limit) * clamp(up, -limit, limit)`` to *output*.

    ``gate`` and ``up`` are the first and second halves of the last dimension
    of *input*. The result is written in place into *output*; nothing is
    returned.

    Args:
        output: Destination tensor whose last dimension is
            ``input.shape[-1] // 2``. It must be viewable as 2-D and have a
            unit stride along its last dimension.
        input: Tensor with at least one dimension and an even last dimension.
            Inputs that are not 2-D are flattened over their leading
            dimensions before the launch.
        limit: Clamp bound, passed to the kernel as a compile-time constant.
    """
    # Validate rank *before* touching the shape. The previous
    # `b, n = input.shape` unpacked first, so its ndim assertion was dead code
    # and any non-2-D input failed with an unrelated unpacking error.
    assert input.ndim >= 1
    n = input.shape[-1]
    assert n % 2 == 0
    d = n // 2

    if input.ndim != 2:
        # Collapse leading dims. `view` (not `reshape`) on the output
        # guarantees we still write into the caller's storage.
        input = input.reshape(-1, n)
        output = output.view(-1, d)

    # The kernel indexes columns with an implicit stride of 1.
    if input.stride(-1) != 1:
        input = input.contiguous()
    assert output.stride(-1) == 1, "output must be contiguous in its last dim"

    b = input.shape[0]

    def grid(meta):
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        """Store the clamp bound.

        Raises:
            ValueError: If *limit* is ``None``.
        """
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        # The gate is bounded from above only; up is clamped symmetrically.
        gate = F.silu(gate).clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Triton implementation; returns a new tensor shaped like ``x`` with
        the last dimension halved."""
        n = x.shape[-1]
        d = n // 2
        out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
        if x.ndim == 2:
            swiglustep_and_mul_triton(out, x, self.limit)
        else:
            # The Triton wrapper works on 2-D tensors, so the documented
            # (batch_size, seq_len, 2 * d) layout has to be flattened first.
            # `out` is freshly allocated, so `view` aliases its storage and
            # the kernel's writes land in the tensor we return.
            swiglustep_and_mul_triton(
                out.view(-1, d), x.reshape(-1, n), self.limit
            )
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"
b/vllm/model_executor/layers/fused_moe/fused_moe.py index 987388692725..120b3c2d1e2c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1939,7 +1939,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai"] + return activation in ["silu", "gelu", "swigluoai", "swiglustep"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index a4b20505ea32..75873a92abdb 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -358,6 +358,11 @@ def apply_moe_activation( torch.ops._C.gelu_and_mul(output, input) elif activation == "swigluoai": torch.ops._C.swigluoai_and_mul(output, input) + elif activation == "swiglustep": + from vllm.model_executor.layers.activation import swiglustep_and_mul_triton + + swiglustep_and_mul_triton(output, input) + # Activations without gated multiplication elif activation == SILU_NO_MUL: output.copy_(F.silu(input)) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f38914a7ce33..95f6cc06527a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -189,6 +189,7 @@ "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step1ForCausalLM": ("step1", "Step1ForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), + "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), @@ -495,6 +496,7 @@ "MedusaModel": ("medusa", "Medusa"), "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"), 
 "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), + "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py new file mode 100644 index 000000000000..f0d7b4a75a9d --- /dev/null +++ b/vllm/model_executor/models/step3p5.py @@ -0,0 +1,894 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Step3p5 model.""" + +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn +from torch.nn.parameter import Parameter + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) 
class FP32ReplicatedLinear(ReplicatedLinear):
    """
    Use FP32 for higher precision.
    """

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        # The layer must have been built with fp32 parameters; the input is
        # upcast to match before delegating to the parent implementation.
        assert self.params_dtype == torch.float32
        return super().forward(x.to(torch.float32))


class Step3p5MLP(nn.Module):
    """Gated MLP: ``down_proj(act(gate_up_proj(x)))``.

    The activation is ``SiluAndMul`` unless the config provides a non-zero
    entry for this layer in ``swiglu_limits_shared``. In that case it is
    ``SwigluStepAndMul`` with that limit.
    """

    def __init__(
        self,
        config: ModelConfig,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections and choose the activation.

        Args:
            config: Object exposing ``swiglu_limits_shared``.
                NOTE(review): callers in this file pass the HF config
                (``vllm_config.model_config.hf_config``), not a vLLM
                ``ModelConfig`` - confirm and tighten the annotation.
            hidden_size: Input/output width.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Must be ``"silu"``.
            quant_config: Forwarded to both linear layers.
            reduce_results: Forwarded to ``down_proj``.
            prefix: Module prefix; the layer index is parsed from it.

        Raises:
            ValueError: If *hidden_act* is not ``"silu"``.
        """
        super().__init__()
        # Fail before allocating any weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )

        self.act_fn = SiluAndMul()
        self.prefix = prefix
        self.hidden_size = hidden_size
        self.limit = None

        # Guarded lookup: a config that lacks the attribute, or whose list is
        # shorter than the layer count, means "no clamp" instead of raising
        # AttributeError/IndexError. This mirrors how FusedMoEBlock reads
        # `swiglu_limits`.
        layer_idx = extract_layer_index(prefix)
        shared_limits = getattr(config, "swiglu_limits_shared", None) or []
        layer_limit = (
            shared_limits[layer_idx] if layer_idx < len(shared_limits) else None
        )
        if layer_limit not in (None, 0):
            self.limit = float(layer_limit)
            self.act_fn = SwigluStepAndMul(limit=self.limit)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Each projection returns a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(hidden_states)
        intermediate_act = self.act_fn(gate_up)
        output, _ = self.down_proj(intermediate_act)
        return output
swa_num_attention_heads + self.total_num_heads = swa_num_attention_heads + else: + sliding_window = None + + if isinstance(rope_theta, list): + rope_theta = rope_theta[self.layer_idx] + + self.rank = get_tensor_model_parallel_rank() + self.partial_rotary_factor = partial_rotary_factor + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + if rope_scaling is not None and not isinstance(rope_scaling, dict): + raise ValueError("rope_scaling must be a dict for Step3p5Attention.") + + rope_parameters: dict[str, Any] = ( + dict(rope_scaling) if rope_scaling is not None else {} + ) + rope_parameters.setdefault("rope_type", "default") + rope_parameters["rope_theta"] = self.rope_theta + rope_parameters["partial_rotary_factor"] = partial_rotary_factor + + self.rotary_emb = get_rope( + head_size=self.head_dim, + max_position=max_position, + 
class FusedMoEBlock(nn.Module):
    """Routed experts plus one shared expert for a single decoder layer.

    Owns the fp32 router (``gate``), the router bias, the shared-expert MLP
    (``share_expert``) and the ``SharedFusedMoE`` layer that combines them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the router, the shared expert and the fused experts.

        Raises:
            ValueError: If the tensor-parallel size exceeds the expert count.
        """
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_idx = extract_layer_index(prefix)

        ep_device_group = get_ep_group().device_group
        self.ep_size = ep_device_group.size()
        self.ep_rank = ep_device_group.rank()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        self.enable_eplb = parallel_config.enable_eplb

        # Expert bookkeeping: physical = logical + redundant, split evenly
        # (integer division) across the expert-parallel ranks.
        self.n_routed_experts = config.moe_num_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}."
            )

        self.gate = FP32ReplicatedLinear(
            config.hidden_size,
            config.moe_num_experts,
            bias=False,
            quant_config=None,
            params_dtype=torch.float32,  # Use FP32 for higher precision.
            prefix=f"{prefix}.gate",
        )
        self.use_moe_router_bias = config.use_moe_router_bias
        assert self.use_moe_router_bias, "Only support use_moe_router_bias is true."
        self.routed_scaling_factor = config.moe_router_scaling_factor
        self.router_bias = nn.Parameter(
            torch.zeros(config.moe_num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        self.need_fp32_gate = config.need_fp32_gate
        assert self.need_fp32_gate, (
            "Router logits must use FP32 precision for numerical stability."
        )

        # Routed experts use plain SiLU unless this layer has a non-zero
        # clamp limit, in which case the "swiglustep" activation is selected.
        activation = "silu"
        swiglu_limits = config.swiglu_limits or []
        swiglu_limit = (
            swiglu_limits[self.layer_idx]
            if self.layer_idx < len(swiglu_limits)
            else None
        )
        if swiglu_limit not in (None, 0):
            swiglu_limit = float(swiglu_limit)
            assert swiglu_limit == 7.0, (
                "Swiglu limit in fused moe block only support 7.0 now."
            )
            activation = "swiglustep"
            logger.debug(
                "step3p5 layer_idx: %s, activation: %s, limit: %s",
                self.layer_idx,
                activation,
                swiglu_limit,
            )

        self.share_expert = Step3p5MLP(
            config=config,
            hidden_size=self.hidden_size,
            intermediate_size=config.share_expert_dim,
            hidden_act="silu",
            reduce_results=False,
            quant_config=quant_config,
            prefix=f"{prefix}.share_expert",
        )
        self.experts = SharedFusedMoE(
            shared_experts=self.share_expert,
            gate=self.gate,
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            activation=activation,
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "moe_router_activation", "sigmoid"),
            e_score_correction_bias=self.router_bias,
            routed_scaling_factor=config.moe_router_scaling_factor,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return routed-plus-shared expert output, shaped like the input.

        ``hidden_states`` must be 2-D: ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        # `share_expert` is always constructed in __init__, so a shared output
        # is always expected. The former `share_expert is None` branch could
        # never run and has been removed.
        shared_output, final_hidden_states = fused_moe_out
        assert shared_output is not None
        final_hidden_states += shared_output

        # Both expert paths were built with reduce_results=False, so the
        # summed result is reduced once here.
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)
rope_scaling=getattr(config, "rope_scaling", None), + sliding_window=getattr(config, "sliding_window", None), + use_head_wise_attn_gate=getattr( + config, "use_head_wise_attn_gate", False + ), + layer_types=getattr(config, "layer_types", []), + use_rope_layers=getattr(config, "use_rope_layers", []), + yarn_only_types=getattr(config, "yarn_only_types", []), + partial_rotary_factor=partial_rotary_factors[layer_idx] + if partial_rotary_factors + else 1.0, + prefix=f"{prefix}.self_attn", + ) + else: + raise ValueError( + f"Unsupported attention implementation: {config.att_impl_type}" + ) + self.use_moe = False + self.tp_group = get_tp_group() + self.use_fused_all_reduce = ( + get_tensor_model_parallel_world_size() > 1 + and get_dp_group().world_size == 1 + ) + if self.use_fused_all_reduce: + logger.warning_once("Enable custom fused all reduce...") + else: + logger.warning_once("Disable custom fused all reduce...") + + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] + else: + moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + if layer_idx in moe_layers_idx: + self.moe = FusedMoEBlock( + vllm_config, + prefix=f"{prefix}.moe", + ) + self.use_moe = True + else: + self.mlp = Step3p5MLP( + config=config, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, config.rms_norm_eps + ) + self.prefix = prefix + + def add_and_maybe_inplace_all_reduce( + self, in1: torch.Tensor, in2: torch.Tensor + ) -> torch.Tensor: + if not self.use_fused_all_reduce: + return in1 + in2 + return self.tp_group.all_reduce(in1 + in2) + + def forward( + self, positions: torch.Tensor, hidden_states: 
@support_torch_compile
class Step3p5Model(nn.Module):
    """Decoder stack: token embedding, decoder layers and final norm.

    Under pipeline parallelism, components that do not live on the current
    rank are replaced by ``PPMissingLayer``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        self.vocab_size = config.vocab_size
        self.config = config

        self.moe_num_experts = config.moe_num_experts

        # The embedding lives on the first PP rank, and also on the last one
        # when the word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3p5DecoderLayer(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this rank's decoder layers.

        Returns the hidden states on the last PP rank and an
        ``IntermediateTensors`` wrapping them on every other rank. The final
        norm is not applied here.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each tensor is tried against three routes, in order:

        1. stacked projections (q/k/v -> ``qkv_proj``, gate/up ->
           ``gate_up_proj``);
        2. fused MoE expert tensors, split along dim 0 into one slice per
           expert;
        3. a direct name match using the parameter's own ``weight_loader``
           (or ``default_weight_loader``).

        Speculative-decoding layers, layers past ``num_hidden_layers``,
        parameters missing on this PP rank, ``expert_bias`` tensors and names
        without a matching parameter are skipped.

        Returns:
            Names (relative to this module) of the parameters that were
            loaded.
        """
        config = self.config
        assert config.num_attention_groups > 1, "Only support GQA"
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, checkpoint_name, shard_id)
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
        ]
        # Checkpoint names of fused expert tensors. They also contain
        # "gate_proj"/"up_proj" and must not be captured by the stacked route.
        moe_checkpoint_names = tuple(entry[1] for entry in expert_params_mapping)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if name.startswith("model."):
                local_name = name[len("model.") :]
                full_name = name
            else:
                local_name = name
                full_name = f"model.{name}" if name else "model"

            # NOTE(review): get_spec_layer_idx_from_weight_name is neither
            # defined nor imported in the visible part of this file - confirm
            # it is provided elsewhere in the module.
            spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            # Skip any layers beyond the main model's depth (e.g., MTP layers)
            if full_name.startswith("model.layers."):
                parts = full_name.split(".")
                if len(parts) > 2 and parts[2].isdigit():
                    layer_idx = int(parts[2])
                    if layer_idx >= config.num_hidden_layers:
                        continue

            # Evaluated once per tensor instead of once per mapping entry.
            is_moe_expert_weight = any(
                ckpt_name in local_name for ckpt_name in moe_checkpoint_names
            )

            # Route 1: stacked (merged) projections.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in local_name or is_moe_expert_weight:
                    continue
                replaced_name = local_name.replace(weight_name, param_name)
                if is_pp_missing_parameter(replaced_name, self):
                    continue
                if replaced_name not in params_dict:
                    continue
                param = params_dict[replaced_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(replaced_name)
                break
            else:
                # Route 2: fused expert tensors, one slice per expert along
                # dim 0.
                for param_name, weight_name, shard_id in expert_params_mapping:
                    if weight_name not in local_name:
                        continue
                    replaced_name = local_name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(replaced_name, self):
                        continue
                    if replaced_name not in params_dict:
                        continue
                    param = params_dict[replaced_name]
                    weight_loader = param.weight_loader
                    moe_expert_num = self.moe_num_experts
                    assert loaded_weight.shape[0] == moe_expert_num
                    for expert_id in range(moe_expert_num):
                        weight_loader(
                            param,
                            loaded_weight[expert_id],
                            replaced_name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(replaced_name)
                    break
                else:
                    # Route 3: direct name match.
                    if is_pp_missing_parameter(local_name, self):
                        continue
                    if "expert_bias" in local_name:
                        logger.warning_once("ignore expert_bias")
                        continue
                    if local_name not in params_dict:
                        continue
                    param = params_dict[local_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(local_name)
        return loaded_params
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Turn hidden states into logits.

    The hidden states are first passed through ``self.model.norm`` and the
    result is fed, together with ``self.lm_head``, to
    ``self.logits_processor``.
    """
    normed_states = self.model.norm(hidden_states)
    return self.logits_processor(self.lm_head, normed_states)
def get_spec_layer_idx_from_weight_name(
    config: ModelConfig, weight_name: str
) -> int | None:
    """Return the speculative (MTP) layer index a weight belongs to.

    Speculative layers are numbered directly after the main decoder stack:
    ``num_hidden_layers`` up to
    ``num_hidden_layers + num_nextn_predict_layers - 1``.

    Args:
        config: Object exposing ``num_hidden_layers`` and, optionally,
            ``num_nextn_predict_layers``.
        weight_name: Checkpoint weight name, either ``layers.<idx>.…``
            (as seen by Step3p5Model) or ``model.layers.<idx>.…``
            (as seen by Step3p5MTP).

    Returns:
        The matching layer index, or ``None`` when the weight is not part of
        a speculative layer or the config declares no such layers.
    """
    # A missing attribute and an explicit None both mean "no MTP layers";
    # comparing None with 0 would raise TypeError.
    num_spec_layers = getattr(config, "num_nextn_predict_layers", None) or 0
    if num_spec_layers <= 0:
        return None

    first_spec_layer = config.num_hidden_layers
    for layer_idx in range(first_spec_layer, first_spec_layer + num_spec_layers):
        # The trailing dot keeps "layers.4." from matching "layers.45.".
        if weight_name.startswith(
            (f"layers.{layer_idx}.", f"model.layers.{layer_idx}.")
        ):
            return layer_idx
    return None
config.rms_norm_eps) + self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = Step3p5DecoderLayer( + vllm_config, + prefix=f"{prefix}.mtp_block", + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) + + hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states) + return hidden_states + + +class Step3p5AMultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict( + { + str(idx): Step3p5AMultiTokenPredictorLayer( + vllm_config, + f"{prefix}.layers.{idx}", + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = spec_step_idx % self.num_mtp_layers + return 
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load the MTP (speculative) layers and shared embedding from a checkpoint.

    Only weights that belong to a speculative layer, or whose name contains
    ``embed_tokens``, are considered; everything else is skipped.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The names of the parameters that received data.

    Raises:
        RuntimeError: If a required parameter was not found in the
            checkpoint. Single-element, non-trainable ``k_scale`` /
            ``v_scale`` / ``q_scale`` / ``prob_scale`` parameters are
            optional and never trigger this error.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    expert_params_mapping = [
        (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
        (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
        (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
    ]

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if "embed_tokens" not in name and spec_layer is None:
            continue
        name = self._rewrite_spec_layer_name(spec_layer, name)
        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            if "experts" in name or "moe" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                # The checkpoint tensor holds all experts along dim 0; each
                # slice is passed to the loader with its expert id.
                for expert_id in range(loaded_weight.shape[0]):
                    loaded_weight_expert = loaded_weight[expert_id]
                    weight_loader(
                        param,
                        loaded_weight_expert,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                loaded_params.add(name)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                # `and` binds tighter than `or`, so every name containing
                # "tok_embeddings" is skipped regardless of the bias test.
                if (
                    name.endswith(".bias")
                    and name not in params_dict
                    or "tok_embeddings" in name
                ):
                    continue

                if spec_layer is not None and ".transformer." in name:
                    name = name.replace(".transformer.", ".")
                if "shared_head" in name:
                    name = name.replace("shared_head.output", "shared_head.head")
                if "embed_tokens" in name:
                    assert (
                        hasattr(self.config, "num_nextn_predict_layers")
                        and self.config.num_nextn_predict_layers > 0
                    )
                    name = "model.embed_tokens.weight"
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
        # Reached after a successful load in any branch; skipped weights
        # leave through `continue` before getting here.
        loaded_params.add(name)
    params_need_to_load = set(params_dict.keys())
    # Some KV cache scales are optional: checkpoints may omit them and vLLM
    # will fall back to default scales during initialization.
    optional_params = {
        name
        for name, param in params_dict.items()
        if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
        and getattr(param, "numel", lambda: 0)() == 1
        and getattr(param, "requires_grad", False) is False
    }
    params_need_to_load -= optional_params
    # Only parameters that are required but absent are an error. A plain
    # inequality test also fires when the checkpoint *does* provide an
    # optional scale, in which case nothing is missing and indexing the
    # empty difference would raise IndexError.
    missing_params = params_need_to_load - loaded_params
    if missing_params:
        param_name_example = min(missing_params)
        raise RuntimeError(
            "Some parameters like "
            f"{param_name_example} are not in the checkpoint and will falsely "
            "use random initialization"
        )
    return loaded_params
class Step3p5ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Step3p5 model.

    Step3p5 uses the <think>...</think> format, but it tends to emit an extra
    newline immediately before and/or after the </think> token. This parser trims:
        - the newline right before </think>
        - the newline right after </think>
    """

    @property
    def start_token(self) -> str:
        """Marker that opens the reasoning section."""
        return "<think>"

    @property
    def end_token(self) -> str:
        """Marker that closes the reasoning section."""
        return "</think>"

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # Used to hold a trailing "\n" from reasoning content so we can decide
        # whether it is immediately before </think>.
        self._pending_reasoning_newline = False

        # Used to delay the reasoning end detection.
        # This is necessary to remove the newline that appears immediately
        # after </think>, which may cause the end detection to be delayed by
        # one round.
        self.end_offset = 1

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Report whether reasoning has ended, delayed by one call.

        While ``end_offset`` is positive, a call that sees the end token in
        ``input_ids`` decrements ``end_offset`` and returns ``False``.
        Otherwise the result is ``end_offset < 1``. Note that this method
        mutates parser state.
        """
        if self.end_token_id in input_ids and self.end_offset > 0:
            self.end_offset -= 1
            return False
        return self.end_offset < 1

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        """Streaming variant of :meth:`is_reasoning_end`.

        ``delta_ids`` is not consulted; the decision is the same delayed
        check on ``input_ids``.
        """
        return self.is_reasoning_end(input_ids)

    def extract_reasoning(
        self,
        model_output: str,
        request: ChatCompletionRequest | ResponsesRequest,
    ) -> tuple[str | None, str | None]:
        """Split a complete output into ``(reasoning, content)``.

        Uses the base-class split, then removes one trailing newline from the
        reasoning and one leading newline from the content. Empty strings
        are returned as ``None``.
        """
        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = reasoning.removesuffix("\n")
        if content is not None:
            content = content.removeprefix("\n")
        return reasoning or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Classify one streamed delta as reasoning and/or content.

        Returns ``None`` when the delta carries nothing to emit (for example
        when it is only the newline that follows the end token).
        """
        # Drop the immediate newline that models often emit after </think>.
        if previous_text.endswith(self.end_token) and delta_text:
            if delta_text == "\n":
                return None
            elif delta_text.startswith("\n"):
                remaining = delta_text.removeprefix("\n")
                return DeltaMessage(content=remaining) if remaining else None

        ret = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )

        if ret is None:
            return None

        # Compatibility path for models that don't generate the start token:
        # treat everything before </think> as reasoning and everything after
        # as content.
        if (
            self.start_token_id not in previous_token_ids
            and self.start_token_id not in delta_token_ids
        ):
            if self.end_token_id in delta_token_ids:
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token) :]
                ret = DeltaMessage(reasoning=reasoning, content=content or None)
            elif self.end_token_id in previous_token_ids:
                ret = DeltaMessage(content=delta_text)
            else:
                ret = DeltaMessage(reasoning=delta_text)

        reasoning_to_output = ret.reasoning
        content_to_output = ret.content

        # Reasoning: handle the newline immediately before </think>.
        if reasoning_to_output is not None:
            # Re-attach a newline held back from the previous delta; it
            # turned out not to be directly in front of the end token.
            if self._pending_reasoning_newline:
                reasoning_to_output = "\n" + reasoning_to_output
                self._pending_reasoning_newline = False

            if reasoning_to_output.endswith("\n"):
                reasoning_to_output = reasoning_to_output.removesuffix("\n")
                if self.end_token in delta_text:
                    # Trailing "\n" is right before </think>, drop it.
                    self._pending_reasoning_newline = False
                else:
                    # Hold the trailing "\n" until we know whether </think>
                    # follows.
                    self._pending_reasoning_newline = True

        # Content: handle the newline immediately after </think>.
        if content_to_output is not None:
            # No need to get into parser again to remove newline after
            # </think>.
            self.end_offset -= 1

            # If we have content, reasoning must have ended.
            self._pending_reasoning_newline = False

            if self.end_token in delta_text and content_to_output.startswith("\n"):
                content_to_output = content_to_output.removeprefix("\n")

        reasoning_to_output = reasoning_to_output or None
        content_to_output = content_to_output or None
        if reasoning_to_output is None and content_to_output is None:
            return None

        return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output)
""" + + def __init__(self): + self.reset_streaming_state() + + # Tool configuration information + self.tools: list[ChatCompletionToolsParam] | None = None + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.function_start_token: str = " DeltaMessage: + """ + Parse single streaming XML chunk and return Delta response + This is the actual streaming interface that receives chunks + one by one and maintains internal state + + Args: + xml_chunk: Single XML chunk string + Returns: + DeltaMessage: Contains delta information generated by this chunk, + returns empty response if no complete elements + """ + # Record delta count before processing + initial_delta_count = len(self.deltas) + + self.streaming_buffer += xml_chunk + + found_elements = self._process_complete_xml_elements() + + if found_elements: + # If complete elements found, check if end events were missed + # some tags may not have been triggered + try: + new_deltas = self.deltas[initial_delta_count:] + # If this chunk contains + # but didn't generate '}', then complete it + if ( + self.current_call_id is not None + and self.function_end_token in xml_chunk + ): + # - Added '}' (non-empty parameter ending) + # - Added '{}' (empty parameter function) + has_function_close = any( + ( + td.tool_calls + and any( + ( + tc.function + and tc.id == self.current_call_id + and isinstance(tc.function.arguments, str) + and (tc.function.arguments in ("}", "{}")) + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) + if not has_function_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.current_function_name: + self._end_element("function") + # If this chunk contains + # but didn't generate final empty delta, then complete it + if ( + self.current_call_id is not None + and self.tool_call_end_token in xml_chunk + ): + has_toolcall_close = any( + ( + td.tool_calls + and any( + ( + tc.type == "function" + and tc.function + 
and tc.function.arguments == "" + and tc.id == self.current_call_id + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) + if not has_toolcall_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.current_function_name: + self._end_element("function") + self._end_element("tool_call") + except Exception as e: + logger.warning("Error with fallback parsing: %s", e) + # Merge newly generated deltas into single response + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count + ) + return result_delta + else: + # No complete elements, check if there's unoutput text content + if self.text_content_buffer and self.tool_call_index == 0: + # Has text content but no tool_call yet, output text content + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + # Clear buffer to avoid duplicate output + self.text_content_buffer = "" + return text_delta + + # If this chunk contains end tags but wasn't triggered by parser, + # manually complete end events + # Only execute when still on the same call as when entered, + # to prevent accidentally closing new calls + # in multi scenarios + if self.current_call_id is not None and ( + self.function_end_token in xml_chunk + or self.tool_call_end_token in xml_chunk + ): + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.function_end_token in xml_chunk and self.current_function_name: + self._end_element("function") + if self.tool_call_end_token in xml_chunk: + self._end_element("tool_call") + # Return the merged delta result generated by this fallback + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count + ) + return result_delta + + # No complete elements, return empty response + return DeltaMessage(content=None) + + def _escape_xml_special_chars(self, text: str) -> str: + """ + Escape XML special characters + Args: 
+ text: Original text + Returns: + Escaped text + """ + xml_escapes = { + "&": "&", + "<": "<", + ">": ">", + '"': """, + "'": "'", + } + + for char, escape in xml_escapes.items(): + text = text.replace(char, escape) + + return text + + def _process_complete_xml_elements(self) -> bool: + """ + Process complete XML elements in buffer + + Returns: + bool: Whether complete elements were found and processed + """ + found_any = False + + while self.last_processed_pos < len(self.streaming_buffer): + # Find next complete xml element + element, end_pos = self._find_next_complete_element(self.last_processed_pos) + if element is None: + # No complete element found, wait for more data + break + + # Check if this element should be skipped + if self._should_skip_element(element): + self.last_processed_pos = end_pos + continue + + # Found complete XML element, process it + try: + preprocessed_element = self._preprocess_xml_chunk(element) + # Check if this is the first tool_call start + if ( + ( + preprocessed_element.strip().startswith("") + or preprocessed_element.strip().startswith("") + and self.tool_call_index > 0 + and self.current_call_id + and self.current_function_name + ): + # Reset parser state but preserve generated deltas + if self.current_param_name: + self._end_element("parameter") + if self.current_function_open: + self._end_element("function") + # Output final tool_call tail delta + final_delta = DeltaMessage( + role=None, + content=None, + reasoning_content=None, + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ], + ) + self._emit_delta(final_delta) + # Reset XML parser and current call state + self._reset_xml_parser_after_tool_call() + # Parse preprocessed element + self.parser.Parse(preprocessed_element, False) + found_any = True + + except Exception as e: + logger.warning("Error when parsing XML elements: %s", e) + + # Update processed 
position + self.last_processed_pos = end_pos + + return found_any + + def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str: + """ + Fallback: fix incomplete ) + Examples: , + Also handles missing = cases: -> , + -> + Only fixes tags that pass validation (parameter exists in tool definition) + """ + # First, handle missing = cases for function tags + chunk = self._fix_missing_equals_in_function_tag(chunk) + + for tag_type in ["parameter", "function"]: + pattern = f"<{tag_type}=" + if pattern not in chunk: + continue + + start_idx = chunk.find(pattern) + after_tag = chunk[start_idx:] + gt_pos = after_tag.find(">") + lt_pos = after_tag.find("<", len(pattern)) + + # Skip if already well-formed + if ( + gt_pos != -1 + and (lt_pos == -1 or gt_pos < lt_pos) + and pattern in after_tag[:gt_pos] + ): + continue + + # Extract tag name (stop at space, newline, or <) + content = chunk[start_idx + len(pattern) :] + end_pos = next( + (i for i, ch in enumerate(content) if ch in (" ", "\n", "<")), + len(content), + ) + tag_name = content[:end_pos] + + if not tag_name: + continue + + # Remove duplicate prefix: ", 1 + ) + + return chunk + + def _fix_missing_equals_in_function_tag(self, chunk: str) -> str: + """ + Fix missing = in function tags: or + Examples: + -> + -> + Only fixes if function name exists in tool definition + """ + # already correct + if " (with space/newline but no =) + pattern1 = r"" + match1 = re.search(pattern1, chunk) + if match1: + func_name = match1.group(1).strip() + # must validate function name exists before fixing + if func_name and self._validate_function_name(func_name): + original = match1.group(0) + fixed = f"" + chunk = chunk.replace(original, fixed, 1) + return chunk + + # Pattern 2: (no space, no =) + # only match " + match2 = re.search(pattern2, chunk) + if match2: + func_name = match2.group(1).strip() + # must validate function name exists before fixing + if func_name and self._validate_function_name(func_name): + original = match2.group(0) + 
fixed = f"" + chunk = chunk.replace(original, fixed, 1) + return chunk + + return chunk + + def _validate_function_name(self, func_name: str) -> bool: + """Check if function name exists in tool definitions""" + if not self.tools: + return False + + for tool in self.tools: + if ( + hasattr(tool, "type") + and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + and tool.function.name == func_name + ): + return True + + return False + + def _validate_parameter_name(self, param_name: str) -> bool: + """Check if parameter exists in current function's tool definition""" + if not self.tools or not self.current_function_name: + return True + + for tool in self.tools: + if ( + hasattr(tool, "type") + and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + and tool.function.name == self.current_function_name + ): + if not hasattr(tool.function, "parameters"): + return True + params = tool.function.parameters + if isinstance(params, dict): + properties = params.get("properties", params) + return param_name in properties + break + + return True + + def _should_skip_element(self, element: str) -> bool: + """ + Determine whether an element should be skipped + + Args: + element: Element to evaluate + + Returns: + bool: True means should skip, False means should process + """ + + # If it's a tool_call XML tag, don't skip + if ( + element.startswith(self.tool_call_start_token) + or element.startswith(self.function_start_token) + or element.startswith(self.parameter_start_token) + ): + return False + + # If currently not parsing tool calls and not blank, + # collect this text instead of skipping + # Only process other XML elements after tool_call appears, + # otherwise treat as plain text + if self.current_call_id is None and element: + # Collect text content to buffer + self.text_content_buffer += element + return True # Still skip, but content has been collected + + # If currently parsing tool 
calls, + # this might be parameter value, don't skip + if self.current_call_id is not None: + return False + + # Skip blank content + return not element + + def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: + """ + Find next complete XML element from specified position + + Args: + start_pos: Position to start searching + + Returns: + (Complete element string, element end position), + returns (None, start_pos) if no complete element found + """ + buffer = self.streaming_buffer[start_pos:] + + if not buffer: + return None, start_pos + + if buffer.startswith("<"): + # Check if this is an incomplete parameter/function tag + # e.g., " not in buffer.split("\n")[0] + ) + is_incomplete_func = ( + buffer.startswith("" not in buffer.split("\n")[0] + ) + + if is_incomplete_param or is_incomplete_func: + # Find the corresponding closing tag + tag_type = "parameter" if is_incomplete_param else "function" + closing_tag = f"" + closing_pos = buffer.find(closing_tag) + + if closing_pos != -1: + # Found closing tag, return complete element including closing tag + complete_element = buffer[: closing_pos + len(closing_tag)] + return complete_element, start_pos + closing_pos + len(closing_tag) + + # Need to ensure no new < appears, + # find the nearest one between < and > + tag_end = buffer.find("<", 1) + tag_end2 = buffer.find(">", 1) + if tag_end != -1 and tag_end2 != -1: + # Next nearest is < + if tag_end < tag_end2: + return buffer[:tag_end], start_pos + tag_end + # Next nearest is >, means found XML element + else: + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 + elif tag_end != -1: + return buffer[:tag_end], start_pos + tag_end + elif tag_end2 != -1: + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 + else: + # If currently not parsing tool calls (entering a tool_call), + # check if starts with or + if buffer == ""[: len(buffer)]: + # Might be start of , wait for more data + return None, start_pos + elif ( + buffer.startswith(" 
DeltaMessage: + """ + Merge newly generated deltas from this processing + into a single DeltaMessage + + Args: + initial_count: Delta count before processing + + Returns: + Merged DeltaMessage containing all newly generated delta information + """ + if len(self.deltas) <= initial_count: + return DeltaMessage(content=None) + + # Get newly generated deltas + new_deltas = self.deltas[initial_count:] + + if len(new_deltas) == 1: + # Only one new delta, return directly + return new_deltas[0] + + # Merge multiple new deltas + merged_tool_calls: list[DeltaToolCall] = [] + merged_content: str = "" + + for delta in new_deltas: + if delta.content: + merged_content += delta.content + if delta.tool_calls: + # For tool_calls, we need to intelligently merge arguments + for tool_call in delta.tool_calls: + # Find if there's already a tool_call with the same call_id + existing_call = None + for existing in merged_tool_calls: + if existing.id == tool_call.id: + existing_call = existing + break + + if existing_call and existing_call.function: + # Merge to existing tool_call + if tool_call.function and tool_call.function.name: + existing_call.function.name = tool_call.function.name + if ( + tool_call.function + and tool_call.function.arguments is not None + ): + if existing_call.function.arguments is None: + existing_call.function.arguments = "" + + # For streaming JSON parameters, + # simply concatenate in order + new_args = tool_call.function.arguments + existing_call.function.arguments += new_args + if tool_call.type: + existing_call.type = tool_call.type + else: + # Add new tool_call + merged_tool_calls.append(tool_call) + + return DeltaMessage( + content=merged_content if merged_content else None, + tool_calls=merged_tool_calls, + ) + + def _preprocess_xml_chunk(self, chunk: str) -> str: + """ + Preprocess XML chunk, handle non-standard formats, + and escape special characters + + Args: + chunk: Original XML chunk + + Returns: + Processed XML chunk + """ + + # Check if this is a 
tool_call related element + is_tool_call = False + if chunk.startswith(self.tool_call_start_token) or chunk.startswith( + self.tool_call_end_token + ): + is_tool_call = True + # Check for function tags (including malformed ones without =) + # , , , + if ( + chunk.startswith(self.function_start_token) + or chunk.startswith(self.function_end_token) + or chunk.startswith(" + # This handles cases like: format -> + processed = re.sub(r"]+)>", r'', chunk) + # Handle format -> + processed = re.sub(r"]+)>", r'', processed) + + original_chunk = chunk + # If in parameter value accumulation mode + if self._pre_inside_parameter: + # Parameter end: output accumulated raw text + # safely then return + if processed.startswith(""): + body_text = self._pre_param_buffer + # Trigger deferred parsing mode + # literal_eval+json output in end_element + self.defer_current_parameter = True + self.deferred_param_raw_value = body_text + # Clean up state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + safe_text = self._escape_xml_special_chars(body_text) + return f"{safe_text}" + else: + # If this is the first block of content after entering parameter + # evaluate if deferred parsing is needed; + # If not needed, exit accumulation mode + # and pass through directly + if self._pre_param_buffer == "": + # Get current parameter type + param_type = ( + self._get_param_type(self._pre_current_param_name) + if self._pre_current_param_name + else "string" + ) + # Only these types need deferred parsing to + # handle Python literals containing single quotes + is_object_type = param_type in ["object"] + is_complex_type = ( + param_type in ["array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + + # Only delay when contains container symbols + # and has single quotes and is complex type + has_container_hint = ( + ("[" in original_chunk) + or ("{" in original_chunk) + or ("(" in original_chunk) + ) 
+ + # Determine if deferred parsing is needed + need_defer = False + if is_complex_type: + # Complex type, always need deferred parsing + need_defer = True + elif ( + is_object_type + and has_container_hint + and ("'" in original_chunk) + ): + # Object type with container symbols + # and single quotes, need deferred parsing + need_defer = True + + if not need_defer: + # No need for deferred parsing, + # exit parameter mode directly + self._pre_inside_parameter = False + return self._escape_xml_special_chars(original_chunk) + self._pre_param_buffer += original_chunk + return "" + + # Parameter start: enable accumulation + if processed.startswith("', processed) + if m: + self._pre_current_param_name = m.group(1) + self._pre_inside_parameter = True + self._pre_param_buffer = "" + return processed + + # If processed doesn't contain special_token, escape processed + # This is because XML parsing encounters special characters + # and reports errors, so escaping is needed + if not is_tool_call: + processed = self._escape_xml_special_chars(processed) + return processed + + def _emit_delta(self, delta: DeltaMessage): + """Emit Delta response (streaming output)""" + self.deltas.append(delta) + + def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None): + """Before starting to process new elements, + if there are unclosed tags from before, + automatically complete their endings to the parser. + - If there are unclosed parameters, + it's equivalent to feeding `` + - When about to start a new function or tool_call, + if there are unclosed functions, complete ``. + - When about to start a new tool_call, + if there are unclosed tool_calls, complete ``. 
+ """ + # First close unclosed parameters + if self.current_param_name: + self._end_element("parameter") + + # If about to start new function or tool_call, + # and there are unclosed functions, close function first + if incoming_tag in ("function", "tool_call") and self.current_function_name: + self._end_element("function") + + # If about to start new tool_call, + # and there are unclosed tool_calls, close tool_call first + if incoming_tag == "tool_call" and self.current_call_id: + self._end_element("tool_call") + + def _start_element(self, name: str, attrs: dict[str, str]): + """Handle XML start element events""" + + if name == "root": + return + + if name == "tool_call": + # Before opening new tool_call, + # automatically complete previous unclosed tags + self._auto_close_open_parameter_if_needed("tool_call") + + self.parameters = {} + self.current_call_id = make_tool_call_id() + self.current_param_is_first = True + self.tool_call_index += 1 + elif name.startswith("function") or (name == "function"): + # If missing tool_call, manually complete + if not self.current_call_id: + self._start_element("tool_call", {}) + # Before opening new function, + # automatically complete previous unclosed tags (parameter/function) + self._auto_close_open_parameter_if_needed("function") + function_name = self._extract_function_name(name, attrs) + self.current_function_name = function_name + self.current_function_open = True + if function_name: + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=function_name, arguments="" + ), + ) + ] + ) + self._emit_delta(delta) + elif name.startswith("parameter") or (name == "parameter"): + # If previous parameter hasn't ended normally, + # complete its end first, then start new parameter + self._auto_close_open_parameter_if_needed("parameter") + param_name = self._extract_parameter_name(name, attrs) + self.current_param_name 
= param_name + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False # Reset start quote flag + + # Only output parameter name and colon, + # don't output quotes + # decide after parameter value type is determined + if param_name: + if not self.parameters: + # First parameter + # start JSON, only output parameter name and colon + json_start = f'{{"{param_name}": ' + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_start + ), + ) + ] + ) + self._emit_delta(delta) + self.current_param_is_first = True + else: + # Subsequent parameters + # add comma and parameter name, no quotes + json_continue = f', "{param_name}": ' + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_continue + ), + ) + ] + ) + self._emit_delta(delta) + self.current_param_is_first = False + + def _char_data(self, data: str): + """Handle XML character data events""" + if data and self.current_param_name: + # If preprocessing stage determines deferred parsing is needed, + # only cache character data, no streaming output + if self.defer_current_parameter: + original_data = data + if self.should_emit_end_newline: + original_data = "\n" + original_data + self.should_emit_end_newline = False + if original_data.endswith("\n"): + self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + return + + param_type = self._get_param_type(self.current_param_name) + + # Check if this is the first time receiving data for this parameter + # If this is the first packet of data and starts with \n, remove \n + if not self.current_param_value and data.startswith("\n"): + data = data[1:] + + # Output start quote for string type (if 
not already output) + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + and not self.start_quote_emitted + ): + quote_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) + self._emit_delta(quote_delta) + self.start_quote_emitted = True + + if not data: + return + + original_data = data + # Delay output of trailing newline + if self.should_emit_end_newline: + original_data = "\n" + original_data + self.should_emit_end_newline = False + if original_data.endswith("\n"): + self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + + # convert parameter value by param_type + converted_value = self._convert_param_value( + self.current_param_value, param_type + ) + output_data = self._convert_for_json_streaming(converted_value, param_type) + + delta_data = output_data[len(self.current_param_value_converted) :] + self.current_param_value_converted = output_data + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=delta_data), + ) + ] + ) + self._emit_delta(delta) + + def _end_element(self, name: str): + """Handle XML end element events""" + + if name == "root": + return + + # If function or tool_call ends and there are still unclosed parameters, + # complete parameter end first + if ( + name.startswith("function") or name == "function" or name == "tool_call" + ) and self.current_param_name: + self._auto_close_open_parameter_if_needed() + + if ( + name.startswith("parameter") or name == "parameter" + ) and self.current_param_name: + # End current parameter + param_name = self.current_param_name + param_value = self.current_param_value + + # If in deferred parsing mode, + # perform overall parsing on raw 
content + # accumulated in preprocessing stage and output once + if self.defer_current_parameter: + raw_text = ( + self.deferred_param_raw_value + if self.deferred_param_raw_value + else param_value + ) + parsed_value = None + output_arguments = None + try: + # If previously delayed trailing newline, + # add it back before parsing + if self.should_emit_end_newline: + raw_for_parse = raw_text + "\n" + else: + raw_for_parse = raw_text + parsed_value = ast.literal_eval(raw_for_parse) + output_arguments = json.dumps(parsed_value, ensure_ascii=False) + except Exception: + # Fallback: output as string as-is + output_arguments = json.dumps(raw_text, ensure_ascii=False) + parsed_value = raw_text + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=output_arguments + ), + ) + ] + ) + self._emit_delta(delta) + + # Clean up and store + self.should_emit_end_newline = False + self.parameters[param_name] = parsed_value + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + return + + param_type = self._get_param_type(param_name) + + # convert complete parameter value by param_type + converted_value = self._convert_param_value(param_value, param_type) + + # Decide whether to add end quote based on parameter type + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + # For empty string parameters, need special handling + if not param_value and not self.start_quote_emitted: + # No start quote output, + # directly output complete empty string + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='""'), + ) + ] + ) + 
self._emit_delta(delta) + else: + # Non-empty parameter value, output end quote + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) + self._emit_delta(delta) + + self.should_emit_end_newline = False + # Store converted value + self.parameters[param_name] = converted_value + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + + elif name.startswith("function") or name == "function": + # if there are parameters, close JSON object + if self.parameters: + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="}"), + ) + ] + ) + self._emit_delta(delta) + # return empty object + else: + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="{}"), + ) + ] + ) + self._emit_delta(delta) + self.current_function_open = False + self.current_function_name = ( + None # Clear function name to prevent duplicate closing + ) + + elif name == "tool_call": + # Before ending tool_call, + # ensure function is closed to complete missing right brace + if self.current_function_open: + # If there are still unclosed parameters, close them first + if self.current_param_name: + self._end_element("parameter") + # Close function, ensure output '}' or '{}' + self._end_element("function") + # Final Delta + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ] + ) + self._emit_delta(delta) + + # Check if there's text content to output (between tool_calls) + if 
self.text_content_buffer.strip(): + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + + self._reset_xml_parser_after_tool_call() + + def setup_parser(self): + """Set up XML parser event handlers""" + self.parser.buffer_text = True + self.parser.StartElementHandler = self._start_element + self.parser.EndElementHandler = self._end_element + self.parser.CharacterDataHandler = self._char_data + + def set_tools(self, tools: list[ChatCompletionToolsParam] | None): + """Set tool configuration information""" + self.tools = tools + + def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: + """Extract function name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "function": + return parts[1] + + return None + + def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None: + """Extract parameter name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "parameter": + return parts[1] + + return None + + def _get_param_type(self, param_name: str) -> str: + """Get parameter type based on tool configuration, defaults to string + Args: + param_name: Parameter name + + Returns: + Parameter type + """ + if not self.tools or not self.current_function_name: + return "string" + + for tool in self.tools: + if not hasattr(tool, "type") or not ( + hasattr(tool, "function") and hasattr(tool.function, "name") + ): + continue + if ( + tool.type == "function" + and tool.function.name == self.current_function_name + ): + if not hasattr(tool.function, "parameters"): + return "string" + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + properties = params["properties"] + if param_name in properties and isinstance( + 
properties[param_name], dict + ): + return self.repair_param_type( + str(properties[param_name].get("type", "string")) + ) + elif isinstance(params, dict) and param_name in params: + param_config = params[param_name] + if isinstance(param_config, dict): + return self.repair_param_type( + str(param_config.get("type", "string")) + ) + break + return "string" + + def repair_param_type(self, param_type: str) -> str: + """Repair unknown parameter types by treating them as string + Args: + param_type: Parameter type + + Returns: + Repaired parameter type + """ + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + or param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + or param_type.startswith("num") + or param_type.startswith("float") + or param_type in ["boolean", "bool", "binary"] + or ( + param_type in ["object", "array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + ): + return param_type + else: + return "string" + + def _convert_param_value(self, param_value: str, param_type: str) -> Any: + """Convert value based on parameter type + Args: + param_value: Parameter value + param_type: Parameter type + + Returns: + Converted value + """ + if param_value.lower() == "null": + return None + + param_type = param_type.strip().lower() + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' is not an integer, degenerating to string.", + param_value, + ) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + 
float_param_value: float = float(param_value) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' is not a float, degenerating to string.", + param_value, + ) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + return param_value == "true" + else: + return param_value + + def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str: + """Convert converted_value based on + whether it's empty and if type is string + Args: + converted_value: Converted value + param_type: Parameter type + + Returns: + Converted string for streaming output + """ + # Check if value is empty, but exclude numeric 0 + if converted_value is None or converted_value == "": + return "" + + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + # String type, remove double quotes + return json.dumps(converted_value, ensure_ascii=False)[1:-1] + else: + # Non-string type, return complete JSON string + if not isinstance(converted_value, str): + return json.dumps(converted_value, ensure_ascii=False) + else: + return converted_value + + def _reset_xml_parser_after_tool_call(self): + """ + Each tool_call is treated as a separate XML document, + so we need to reset the parser after each tool_call. 
+ """ + + # recreate XML parser + self.parser = ParserCreate() + self.setup_parser() + + # Reset current tool_call state + if self.current_call_id: + self.last_completed_call_id = self.current_call_id + self.current_call_id = None + self.current_function_name = None + self.current_function_open = False + self.parameters = {} + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.current_param_is_first = False + self.should_emit_end_newline = False + self.start_quote_emitted = False + self.text_content_buffer = "" + + # Reset preprocessing and deferred parsing state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + + +@ToolParserManager.register_module("step3p5") +class Step3p5ToolParser(ToolParser): + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + self.parser = StreamingXMLToolCallParser() + + # Add missing attributes for compatibility with serving_chat.py + self.prev_tool_call_arr: list[dict] = [] + self.streamed_args_for_tool: list[str] = [] + + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + self.parser.reset_streaming_state() + # Reset tool call tracking arrays for new extraction + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] + if request: + self.parser.set_tools(request.tools) + result = self.parser.parse_single_streaming_chunks(model_output) + if not result.tool_calls: + return ExtractedToolCallInformation( + tool_calls=[], + tools_called=False, + content=result.content, + ) + else: + tool_calls = [] + for tool_call in result.tool_calls: + if tool_call.function and tool_call.function.name: + tool_calls.append( + ToolCall( + id=tool_call.id, + 
type=tool_call.type, + function=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + ) + + # Update tool call tracking arrays for compatibility + tool_index = ( + tool_call.index + if tool_call.index is not None + else len(self.prev_tool_call_arr) - 1 + ) + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({"name": "", "arguments": ""}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + # Update tool call information + self.prev_tool_call_arr[tool_index]["name"] = ( + tool_call.function.name + ) + self.prev_tool_call_arr[tool_index]["arguments"] = ( + tool_call.function.arguments + ) + + # Update streamed arguments + if tool_call.function.arguments: + self.streamed_args_for_tool[tool_index] = ( + tool_call.function.arguments + ) + + return ExtractedToolCallInformation( + tool_calls=tool_calls, + tools_called=len(tool_calls) > 0, + content=result.content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + if not previous_text: + self.parser.reset_streaming_state() + # Reset tool call tracking arrays for new streaming session + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] + if request: + self.parser.set_tools(request.tools) + + # Model sometimes outputs separately causing delta_text to be empty. 
+ # If there were tool_calls before and all current tool_calls have ended, + # return an empty tool_call for outer streaming output + # to correctly output tool_call field + if not delta_text and delta_token_ids: + open_calls = current_text.count( + self.parser.tool_call_start_token + ) - current_text.count(self.parser.tool_call_end_token) + if ( + open_calls == 0 + and self.parser.tool_call_index > 0 + or not self.parser.tool_call_index + and current_text + ): + return DeltaMessage(content="") + return None + + # Parse the delta text and get the result + result = self.parser.parse_single_streaming_chunks(delta_text) + + # Update tool call tracking arrays based on incremental parsing results + if result and result.tool_calls: + for tool_call in result.tool_calls: + if tool_call.function: + tool_index = ( + tool_call.index + if tool_call.index is not None + else len(self.prev_tool_call_arr) - 1 + ) + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({"name": "", "arguments": ""}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + # Update tool name if provided + if tool_call.function.name: + self.prev_tool_call_arr[tool_index]["name"] = ( + tool_call.function.name + ) + + # Update arguments incrementally + if tool_call.function.arguments is not None: + # Concatenate the incremental arguments + # to the existing streamed arguments + self.prev_tool_call_arr[tool_index]["arguments"] += ( + tool_call.function.arguments + ) + self.streamed_args_for_tool[tool_index] += ( + tool_call.function.arguments + ) + return result + + def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool: + """ + Skip the remaining_call calculation in serving_chat + """ + return False diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 6679c8dd5548..5094e3fbb844 100644 --- a/vllm/transformers_utils/config.py 
+++ b/vllm/transformers_utils/config.py @@ -97,6 +97,7 @@ def __getitem__(self, key): ultravox="UltravoxConfig", step3_vl="Step3VLConfig", step3_text="Step3TextConfig", + step3p5="Step3p5Config", qwen3_asr="Qwen3ASRConfig", qwen3_next="Qwen3NextConfig", lfm2_moe="Lfm2MoeConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 2f8179602b2e..7cd236532154 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -52,6 +52,7 @@ "Step3VLConfig": "vllm.transformers_utils.configs.step3_vl", "Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl", "Step3TextConfig": "vllm.transformers_utils.configs.step3_vl", + "Step3p5Config": "vllm.transformers_utils.configs.step3p5", "Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr", "Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next", "Tarsier2Config": "vllm.transformers_utils.configs.tarsier2", @@ -95,6 +96,7 @@ "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "Step3p5Config", "Qwen3ASRConfig", "Qwen3NextConfig", "Tarsier2Config", diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py new file mode 100644 index 000000000000..435afd938212 --- /dev/null +++ b/vllm/transformers_utils/configs/step3p5.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from transformers.configuration_utils import PretrainedConfig + + +class Step3p5Config(PretrainedConfig): + model_type = "step3p5" + + def __init__( + self, + hidden_size: int = 5120, + intermediate_size: int = 13312, + num_attention_heads: int = 40, + num_attention_groups: int = 8, + num_hidden_layers: int = 48, + max_seq_len: int = 4096, + vocab_size: int = 65536, + rms_norm_eps: float = 1e-5, + moe_every_n_layer: int = 2, + use_moe: bool = False, + moe_intermediate_size: int 
= 10240, + moe_num_experts: int = 16, + moe_top_k: int = 4, + moe_layer_offset: int = 0, + rope_theta: float | list[float] | None = 500000, + rope_scaling: dict[str, Any] | None = None, + head_dim: int | None = None, + share_expert_dim: int | None = None, + norm_expert_weight: bool = True, + bos_token_id: list[int] | int | None = None, + eos_token_id: list[int] | int | None = None, + moe_router_activation: str = "softmax", + moe_router_scaling_factor: float = 1.0, + att_impl_type: str = "GQA", + use_head_wise_attn_gate: bool = False, + use_moe_router_bias: bool = True, + need_fp32_gate: bool = True, + layer_types: list[str] | None = None, + use_rope_layers: list[bool] | None = None, + yarn_only_types: list[str] | None = None, + attention_other_setting: dict[str, Any] | None = None, + num_nextn_predict_layers: int = 0, + swiglu_limits: list[float] | None = None, + swiglu_limits_shared: list[float] | None = None, + max_position_embeddings: int | None = None, + **kwargs, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_attention_groups = num_attention_groups + self.num_hidden_layers = num_hidden_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.rms_norm_eps = rms_norm_eps + self.use_moe = use_moe + self.moe_intermediate_size = moe_intermediate_size + self.moe_every_n_layer = moe_every_n_layer + self.moe_num_experts = moe_num_experts + self.num_experts_per_tok = moe_top_k + self.moe_top_k = moe_top_k + self.moe_layer_offset = moe_layer_offset + + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.head_dim = head_dim + if share_expert_dim is None: + self.share_expert_dim = self.moe_intermediate_size * self.moe_top_k + else: + self.share_expert_dim = share_expert_dim + self.norm_expert_weight = norm_expert_weight + + self.max_position_embeddings = max_position_embeddings + self.moe_router_activation = moe_router_activation + 
self.moe_router_scaling_factor = moe_router_scaling_factor + self.use_moe_router_bias = use_moe_router_bias + self.need_fp32_gate = need_fp32_gate + + self.att_impl_type = att_impl_type + self.use_head_wise_attn_gate = use_head_wise_attn_gate + self.layer_types = layer_types + self.use_rope_layers = use_rope_layers + self.yarn_only_types = yarn_only_types + self.attention_other_setting = attention_other_setting + self.num_nextn_predict_layers = num_nextn_predict_layers + self.swiglu_limits = swiglu_limits + self.swiglu_limits_shared = swiglu_limits_shared + + resolved_bos_token_id = 1 if bos_token_id is None else bos_token_id + resolved_eos_token_id = [2, 3] if eos_token_id is None else eos_token_id + self.bos_token_id = resolved_bos_token_id + self.eos_token_id = resolved_eos_token_id + + super().__init__( + bos_token_id=resolved_bos_token_id, + eos_token_id=resolved_eos_token_id, + **kwargs, + )