From f9af118e1fff17b1928cd0e99b0f415041dd5b86 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Tue, 27 Jan 2026 16:03:09 +0800 Subject: [PATCH 01/34] step3p5-turbo-support --- vllm/config/speculative.py | 10 +- vllm/envs.py | 4 + vllm/model_executor/layers/activation.py | 63 ++ .../layers/fused_moe/deep_gemm_moe.py | 9 +- .../layers/fused_moe/fallback.py | 2 + .../layers/fused_moe/fused_moe.py | 12 +- .../fused_moe/fused_moe_modular_method.py | 1 + vllm/model_executor/layers/fused_moe/layer.py | 6 + .../layers/fused_moe/modular_kernel.py | 18 +- .../fused_moe/router/custom_routing_router.py | 5 +- .../layers/fused_moe/router/router_factory.py | 3 +- .../fused_moe/unquantized_fused_moe_method.py | 1 + vllm/model_executor/layers/fused_moe/utils.py | 11 +- .../layers/quantization/bitsandbytes.py | 1 + .../compressed_tensors_moe.py | 4 + .../layers/quantization/experts_int8.py | 1 + .../model_executor/layers/quantization/fp8.py | 1 + .../layers/quantization/modelopt.py | 2 + .../layers/quantization/quark/quark_moe.py | 2 + vllm/model_executor/models/registry.py | 2 + vllm/model_executor/models/step3p5.py | 990 ++++++++++++++++++ vllm/model_executor/models/step3p5_mtp.py | 360 +++++++ vllm/reasoning/__init__.py | 4 + vllm/reasoning/step3p5_reasoning_parser.py | 135 +++ vllm/tool_parsers/__init__.py | 4 + .../tool_parsers/qwen3coder_tool_parser_rl.py | 806 ++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/step3p5.py | 123 +++ 29 files changed, 2571 insertions(+), 12 deletions(-) create mode 100644 vllm/model_executor/models/step3p5.py create mode 100644 vllm/model_executor/models/step3p5_mtp.py create mode 100644 vllm/reasoning/step3p5_reasoning_parser.py create mode 100644 vllm/tool_parsers/qwen3coder_tool_parser_rl.py create mode 100644 vllm/transformers_utils/configs/step3p5.py diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f3de1e171af2..c7648487866c 100644 
--- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -41,6 +41,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", + "step3p5_mtp" ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] SpeculativeMethod = Literal[ @@ -263,7 +264,14 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) - + + if hf_config.model_type == "step3p5": + hf_config.model_type = "step3p5_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) + hf_config.update( + {"n_predict": n_predict, "architectures": ["Step3p5MTP"]} + ) + if initial_architecture == "MistralLarge3ForCausalLM": hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) diff --git a/vllm/envs.py b/vllm/envs.py index 741a2163c91f..7d2bdadc47e3 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -254,6 +254,7 @@ VLLM_DEBUG_MFU_METRICS: bool = False VLLM_DISABLE_LOG_LOGO: bool = False VLLM_LORA_DISABLE_PDL: bool = False + VLLM_USE_FUSED_ALL_REDUCE: bool = True def get_default_cache_root(): @@ -1631,6 +1632,9 @@ def _get_or_set_default() -> str: # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes # Triton compilation to fail. "VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))), + # If set, step3p5 will use symmcomm inplace all reduce. + "VLLM_USE_FUSED_ALL_REDUCE": + lambda: os.getenv("VLLM_USE_FUSED_ALL_REDUCE", "true").lower() in ("1", "true"), } diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index c8822aed2e3f..3420731542ed 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -22,6 +22,29 @@ logger = init_logger(__name__) +def swigluoai_step_and_mul_out( + out: torch.Tensor, + x: torch.Tensor, + limit: float, +) -> torch.Tensor: + """Out-variant of swigluoai-step activation. 
+ + Writes into `out`: + silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) + """ + # Prefer the fused custom op when available (CUDA); fallback to PyTorch ops + # otherwise. + if x.is_cuda and hasattr(torch.ops._C, "swigluoai_step_and_mul"): + torch.ops._C.swigluoai_step_and_mul(out, x, limit) + else: + gate, up = x.chunk(2, dim=-1) + gate = F.silu(gate) + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + out.copy_(gate * up) + return out + + # --8<-- [start:fatrelu_and_mul] @CustomOp.register("fatrelu_and_mul") class FatreluAndMul(CustomOp): @@ -304,6 +327,46 @@ def extra_repr(self) -> str: return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}" +# --8<-- [start:swigluoai_step_and_mul] +@CustomOp.register("swigluoai_step_and_mul") +class SwigluOAIStepAndMul(CustomOp): + """An activation function for SwiGLU with clamping. + + Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) + where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + # --8<-- [end:swigluoai_step_and_mul] + + def __init__(self, limit: float): + super().__init__() + if limit is None: + raise ValueError("SwigluOAIStepAndMul requires limit to be set.") + self.limit = limit + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + gate, up = x.chunk(2, dim=-1) + gate = F.silu(gate) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + return gate * up + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + swigluoai_step_and_mul_out(out, x, self.limit) + return out + + def extra_repr(self) -> str: + return f"limit={repr(self.limit)}" + + # --8<-- [start:gelu_new] @CustomOp.register("gelu_new") class 
NewGELU(CustomOp): diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index fafcf6de6140..0a94757e335b 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -185,7 +185,7 @@ def workspace_shapes( return (workspace1, workspace2, output) def _act_mul_quant( - self, input: torch.Tensor, output: torch.Tensor, activation: str + self, input: torch.Tensor, output: torch.Tensor, activation: str, activation_limit: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: assert self.block_shape is not None block_k = self.block_shape[1] @@ -199,7 +199,7 @@ def _act_mul_quant( act_out = torch.empty( (M_sum, activation_out_dim), dtype=input.dtype, device=input.device ) - self.activation(activation, act_out, input) + self.activation(activation, act_out, input, activation_limit) a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm( act_out, block_k, @@ -220,7 +220,7 @@ def _act_mul_quant( act_out = torch.empty( (M_sum, activation_out_dim), dtype=input.dtype, device=input.device ) - self.activation(activation, act_out, input) + self.activation(activation, act_out, input, activation_limit) return per_token_group_quant_fp8( act_out, block_k, column_major_scales=True, out_q=output ) @@ -242,6 +242,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + activation_limit: float | None = None, ): assert a1q_scale is not None assert a2_scale is None @@ -290,7 +291,7 @@ def apply( workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim) ) a2q, a2q_scale = self._act_mul_quant( - input=mm1_out.view(-1, N), output=quant_out, activation=activation + input=mm1_out.view(-1, N), output=quant_out, activation=activation, activation_limit=activation_limit ) mm2_out = _resize_cache(workspace2, (M_sum, K)) diff --git 
a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 07e5b80059f0..014bd30e8cbe 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -168,6 +168,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + activation_limit: float | None = None, ): experts = self._select_experts_impl(hidden_states, w1, w2) experts.apply( @@ -186,4 +187,5 @@ def apply( workspace2, expert_tokens_meta, apply_router_weight_on_input, + activation_limit ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 987388692725..2161af17ac8e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1343,6 +1343,7 @@ def inplace_fused_experts( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + activation_limit: float | None = None, ) -> None: fused_experts_impl( hidden_states, @@ -1370,6 +1371,7 @@ def inplace_fused_experts( block_shape, w1_bias, w2_bias, + activation_limit, ) @@ -1398,6 +1400,7 @@ def inplace_fused_experts_fake( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + activation_limit: float | None = None, ) -> None: pass @@ -1435,6 +1438,7 @@ def outplace_fused_experts( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + activation_limit: float | None = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -1462,6 +1466,7 @@ def outplace_fused_experts( block_shape, w1_bias, w2_bias, + activation_limit, ) @@ -1489,6 +1494,7 @@ def outplace_fused_experts_fake( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = 
None, + activation_limit: float | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1613,6 +1619,7 @@ def fused_experts_impl( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + activation_limit: float | None = None, ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: @@ -1841,7 +1848,7 @@ def fused_experts_impl( ) apply_moe_activation( - activation, intermediate_cache2, intermediate_cache1.view(-1, N) + activation, intermediate_cache2, intermediate_cache1.view(-1, N), activation_limit=activation_limit, ) qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( @@ -1988,6 +1995,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + activation_limit: float | None = None, ): # Check constraints. if self.quant_config.use_int4_w4a16: @@ -2075,7 +2083,7 @@ def apply( ) self.activation( - activation, intermediate_cache2, intermediate_cache1.view(-1, N) + activation, intermediate_cache2, intermediate_cache1.view(-1, N), activation_limit=activation_limit ) a2q_scale: torch.Tensor | None = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 7a2244a9bc1d..eaab3bcff677 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -101,6 +101,7 @@ def apply( topk_ids=topk_ids, inplace=self.allow_inplace, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=None if self.disable_expert_map else layer.expert_map, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5fe4bce7a4fc..19c5d8dc92f3 100755 --- 
a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -332,6 +332,7 @@ def __init__( expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, router_logits_dtype: torch.dtype | None = None, + activation_limit: float | None = None, ): super().__init__() @@ -519,6 +520,11 @@ def __init__( self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation + self.activation_limit = activation_limit + if self.activation == "swigluoai-step" and self.activation_limit is None: + raise ValueError( + "activation='swigluoai-step' requires activation_limit to be set." + ) self.router = create_fused_moe_router( top_k=top_k, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 940a2c55f73a..beb78ec7f937 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -414,6 +414,7 @@ def __init__( self.quant_config = quant_config self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers + self.activation_limit: float | None = None @property def expects_unquantized_inputs(self) -> bool: @@ -711,9 +712,14 @@ def adjust_N_for_activation(N: int, activation: str) -> int: return N if is_no_mul else N // 2 def activation( - self, activation: str, output: torch.Tensor, input: torch.Tensor + self, activation: str, output: torch.Tensor, input: torch.Tensor, activation_limit: float | None = None ) -> None: - apply_moe_activation(activation, output, input) + apply_moe_activation( + activation, + output, + input, + activation_limit=activation_limit, + ) def enable_chunking(self): return ( @@ -741,6 +747,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: ExpertTokensMetadata | None, apply_router_weight_on_input: bool, + activation_limit: float | None = None ) -> None: """ This function computes the 
intermediate result of a Mixture of Experts @@ -1139,6 +1146,7 @@ def _fused_experts( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, expert_tokens_meta: ExpertTokensMetadata | None, + activation_limit: float | None = None ) -> torch.Tensor: _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( a1q, w1, w2, topk_ids @@ -1214,6 +1222,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]: workspace2=workspace2, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, + activation_limit=activation_limit, ) return fused_out @@ -1297,6 +1306,7 @@ def forward( global_num_experts: int = -1, expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, + activation_limit: float | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -1326,6 +1336,9 @@ def forward( - torch.Tensor: The output tensor after applying the MoE layer. """ + # Propagate any activation parameters to the experts implementation. 
+ self.fused_experts.activation_limit = activation_limit + if inplace and self.shared_experts is None and not disable_inplace(): output = hidden_states else: @@ -1358,6 +1371,7 @@ def forward( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, expert_tokens_meta=expert_tokens_meta, + activation_limit=activation_limit ) return self._finalize( diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index 0367189ca1ab..a1f931156750 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -21,6 +21,7 @@ def __init__( renormalize: bool = True, enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, + routed_scaling_factor: float = 1.0, ): super().__init__( top_k=top_k, @@ -31,6 +32,7 @@ def __init__( ) self.custom_routing_function = custom_routing_function self.renormalize = renormalize + self.routed_scaling_factor = routed_scaling_factor @property def routing_method_type(self) -> RoutingMethodType: @@ -54,7 +56,8 @@ def _compute_routing( topk=self.top_k, renormalize=self.renormalize, ) - + if self.routed_scaling_factor != 1.0: + topk_weights *= self.routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to( torch.int32 if indices_type is None else indices_type ) diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 890f846d3539..330741dd5f4d 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -40,7 +40,7 @@ def create_fused_moe_router( topk_group: int | None = None, scoring_func: str = "softmax", num_fused_shared_experts: int = 0, - # grouped topk + fused topk bias parameters + # grouped topk + fused topk 
bias parameters/ custom router function routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, # custom routing paramaters @@ -130,6 +130,7 @@ def create_fused_moe_router( renormalize=renormalize, enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, + routed_scaling_factor=routed_scaling_factor, ) if e_score_correction_bias is not None: diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 2ddaf272b147..4bcd367c000d 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -344,6 +344,7 @@ def forward_cuda( topk_ids=topk_ids, inplace=self.use_inplace, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index a4b20505ea32..9cdbf33e0564 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -331,11 +331,12 @@ def apply_moe_activation( activation: str, output: torch.Tensor, input: torch.Tensor, + activation_limit: float | None = None, ) -> torch.Tensor: """ Apply MoE activation function. 
- For *_and_mul activations (silu, gelu, swigluoai): + For *_and_mul activations (silu, gelu, swigluoai, swigluoai-step): - Expects output.size(-1) * 2 == input.size(-1) For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul): @@ -358,6 +359,14 @@ def apply_moe_activation( torch.ops._C.gelu_and_mul(output, input) elif activation == "swigluoai": torch.ops._C.swigluoai_and_mul(output, input) + elif activation == "swigluoai-step": + if activation_limit is None: + raise ValueError( + "activation='swigluoai-step' requires activation_limit to be set." + ) + from vllm.model_executor.layers.activation import swigluoai_step_and_mul_out + + swigluoai_step_and_mul_out(output, input, activation_limit) # Activations without gated multiplication elif activation == SILU_NO_MUL: output.copy_(F.silu(input)) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 8b6b1e445f35..22211a17e445 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -517,6 +517,7 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index dbfa8fb9bd7a..2e3944f0c585 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -672,6 +672,7 @@ def apply( topk_ids, inplace=False, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, 
expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, @@ -1084,6 +1085,7 @@ def apply( topk_ids, inplace=self.use_inplace, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, # TODO(rob): investigate the disable_expert_map introduced by: # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 @@ -1223,6 +1225,7 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, @@ -1980,6 +1983,7 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 5a0bb5d30f9e..e1d594a76730 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -151,6 +151,7 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6436a9ae0abf..208cd1914a82 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1019,6 +1019,7 @@ def apply( topk_ids, inplace=self.use_inplace, activation=layer.activation, + 
activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e76c109eceda..17abc92b3970 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -969,6 +969,7 @@ def apply( topk_ids, inplace=self.use_inplace, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, @@ -1540,6 +1541,7 @@ def apply( topk_ids, inplace=False, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d2f0213e8091..d519390d9089 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -400,6 +400,7 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, @@ -771,6 +772,7 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, + activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, diff --git a/vllm/model_executor/models/registry.py 
b/vllm/model_executor/models/registry.py index f38914a7ce33..95f6cc06527a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -189,6 +189,7 @@ "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "Step1ForCausalLM": ("step1", "Step1ForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), + "Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), @@ -495,6 +496,7 @@ "MedusaModel": ("medusa", "Medusa"), "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"), "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), + "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py new file mode 100644 index 000000000000..1db8d13ca0fb --- /dev/null +++ b/vllm/model_executor/models/step3p5.py @@ -0,0 +1,990 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Step3p5 model.""" +import math +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn + +import vllm.envs as envs +from vllm.attention.layer import Attention +from vllm.v1.attention.backend import AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (get_dp_group, + get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.logger import init_logger +from
vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIStepAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.fused_moe.shared_fused_moe import ( + SharedFusedMoE) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors + +from .interfaces import MixtureOfExperts, SupportsPP +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + +def sigmoid_routing_function(hidden_states: torch.Tensor, + gating_output: torch.Tensor, topk: int, + renormalize: bool): + gating_output = gating_output.float() + gate_prob = torch.sigmoid(gating_output) + gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True) + topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1) + expert_topk_weight = topk_prob + if renormalize: + expert_topk_weight = expert_topk_weight / torch.sum( + expert_topk_weight, dim=-1, keepdim=True) + return expert_topk_weight, indices.to(torch.int32) + +def Step3p5RMSNorm( + hidden_size: int, + eps: float = 1e-6, + zero_centered: bool = True, +): + if zero_centered: + return GemmaRMSNorm(hidden_size, eps) + else: + return RMSNorm(hidden_size, eps) + +def pad_param( + 
weight: torch.Tensor, + name: str, + param: torch.nn.Parameter, + quant_config: Optional[QuantizationConfig] = None, +) -> torch.Tensor: + """Pad 2D weight for groupwise quantization TP sharding. + + Decide whether to pad based on `param.quant_method`: + - None / UnquantizedLinearMethod / UnquantizedFusedMoEMethod => no padding + - otherwise => treat as quantized and pad if groupwise_quant config is found + """ + if weight.dim() != 2: + return weight + + quant_method = getattr(param, "quant_method", None) + if not quant_config or quant_config.get_name( + ) != "groupwise_quant" or not quant_method: + return weight + + world_size = get_tensor_model_parallel_world_size() + group_size = quant_config.group_size + + if ("down_proj.scales" in name) or ("w2_weight_scale" in name): + group_size = 1 + + ic, oc = weight.shape + if ("down" in name) or ("w2" in name): + ic_pad = int( + math.ceil(ic / group_size / world_size) * world_size * + group_size) - ic + out = torch.nn.functional.pad(weight, (0, 0, 0, ic_pad), "constant", 0) + else: + oc_pad = int( + math.ceil(oc / group_size / world_size) * world_size * + group_size) - oc + out = torch.nn.functional.pad(weight, (0, oc_pad, 0, 0), "constant", 0) + + logger.debug( + f"padding {name} ,quant_config={quant_config},original weight.shape: {weight.shape}, padded weight.shape: {out.shape}" + ) + return out + + +def _pad_size_for_groupwise_quant( + size: int, + quant_config: Optional[QuantizationConfig], +) -> int: + """Pad `size` to be a multiple of (group_size * tensor_parallel_world_size). + + This is needed for groupwise quantization TP sharding. 
+ """ + if quant_config is None or quant_config.get_name() != "groupwise_quant": + return size + world_size = get_tensor_model_parallel_world_size() + + group_size = quant_config.group_size + multiple = world_size * group_size + return math.ceil(size / multiple) * multiple + + +class Step3p5MLP(nn.Module): + + def __init__( + self, + config: ModelConfig, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + + intermediate_size = _pad_size_for_groupwise_quant( + intermediate_size, quant_config) + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + self.prefix = prefix + self.hidden_size = hidden_size + self.limit = None + layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + if config.swiglu_limits_shared and config.swiglu_limits_shared[ + layer_idx] is not None and config.swiglu_limits_shared[ + layer_idx] != 0: + self.limit = config.swiglu_limits_shared[layer_idx] + self.act_fn = SwigluOAIStepAndMul(limit=self.limit) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + # dynamo 在sharedfusedmoe里面 合不了silu的两个torch op,不如直接调用cuda op + intermediate_act = self.act_fn.forward_cuda(gate_up) + output, _ = self.down_proj(intermediate_act) + return output + +class Step3p5Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + rope_theta: Optional[Union[float, list[float]]] = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + # Step3p5 specific args + sliding_window: Optional[int] = None, + enable_sink: bool = False, + use_head_wise_attn_gate: bool = False, + layer_types: list = None, + use_rope_layers: list = None, + yarn_only_types: list = None, + swa_num_attention_heads: Optional[int] = None, + partial_rotary_factor: float = 1.0, + zero_centered: bool = True, + ): + super().__init__() + self.hidden_size = hidden_size + self.total_num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + self.layer_idx = extract_layer_index(prefix) + if layer_types: + enable_sliding_window = layer_types[ + self.layer_idx] == "sliding_attention" + else: + enable_sliding_window = 
self.layer_idx % 2 == 0 + if yarn_only_types and layer_types[ + self.layer_idx] not in yarn_only_types: + rope_scaling = None + + if sliding_window is not None and enable_sliding_window: + sliding_window = (sliding_window) + if swa_num_attention_heads is not None: + num_heads = swa_num_attention_heads + self.total_num_heads = swa_num_attention_heads + if enable_sink: + self.sinks = torch.nn.Parameter(torch.empty( + self.total_num_heads // tp_size, dtype=torch.bfloat16), + requires_grad=False) + else: + self.sinks = None + else: + self.sinks = None + sliding_window = None + + if isinstance(rope_theta, list): + rope_theta = rope_theta[self.layer_idx] + + self.rank = get_tensor_model_parallel_rank() + self.partial_rotary_factor = partial_rotary_factor + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.dual_chunk_attention_config = dual_chunk_attention_config + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # vLLM >= 0.7 uses `rope_parameters` for all RoPE scaling variants. + # `rope_scaling` (HF-style) is mapped into `rope_parameters` here to + # preserve the behavior of older Step3p5 implementations. + rope_parameters: dict[str, Any] = { + "rope_type": "default", + "partial_rotary_factor": partial_rotary_factor, + } + if rope_scaling is not None: + if isinstance(rope_scaling, dict): + rope_parameters.update(rope_scaling) + elif isinstance(rope_scaling, (tuple, list)) and rope_scaling and isinstance( + rope_scaling[0], dict): + # Per-layer rope scaling configs. + if self.layer_idx < len(rope_scaling): + rope_parameters.update(rope_scaling[self.layer_idx]) + elif isinstance(rope_scaling, + (tuple, list)) and len(rope_scaling) == 2 and isinstance( + rope_scaling[0], str): + # Legacy tuple format: (type, factor) + rope_parameters.update({ + "rope_type": rope_scaling[0], + "factor": rope_scaling[1], + }) + if "type" in rope_parameters: + rope_parameters.setdefault("rope_type", rope_parameters["type"]) + rope_parameters.pop("type", None) + # Always take the per-layer resolved rope theta, instead of trusting + # any potentially list-valued rope_theta coming from rope_scaling. 
+ rope_parameters["rope_theta"] = self.rope_theta + rope_parameters["partial_rotary_factor"] = partial_rotary_factor + + self.rotary_emb = get_rope( + head_size=self.head_dim, + max_position=max_position, + rope_parameters=rope_parameters, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + + self.q_norm = Step3p5RMSNorm(self.head_dim, + eps=rms_norm_eps, + zero_centered=zero_centered) + self.k_norm = Step3p5RMSNorm(self.head_dim, + eps=rms_norm_eps, + zero_centered=zero_centered) + self.zero_centered = zero_centered + self.use_head_wise_attn_gate = use_head_wise_attn_gate + if use_head_wise_attn_gate: + self.g_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads, + bias=False, + prefix=f"{prefix}.g_proj", + ) + + self.use_rope = True + if use_rope_layers: + self.use_rope = use_rope_layers[self.layer_idx] + + # TODO: Add sink attention + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}, + ) + + self.max_position_embeddings = max_position + assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5 + self.rotary_dim = self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2 + + def qk_norm_rope(self, q, k, positions): + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm(q_by_head.contiguous()) + q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm(k_by_head.contiguous()) + k = k_by_head.view(k.shape) + if self.use_rope: + q, k = self.rotary_emb(positions, q, k) + return q, k + + def forward( + 
self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + q, k = self.qk_norm_rope(q, k, positions) + attn_output = self.attn(q, k, v) + if self.use_head_wise_attn_gate: + extra_dims, _ = self.g_proj(hidden_states) + output = attn_output.view( + *attn_output.shape[:-1], self.num_heads, + self.head_dim) * extra_dims.unsqueeze(-1).sigmoid() + attn_output = output.view(*attn_output.shape) + output, _ = self.o_proj(attn_output) + return output + +class FusedMoEBlock(nn.Module): + + def __init__(self, + config: ModelConfig, + parallel_config: ParallelConfig, + shared_experts: torch.nn.Module, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = ""): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.layer_idx = extract_layer_index(prefix) + + self.ep_size = get_ep_group().device_group.size() + self.ep_rank = get_ep_group().device_group.rank() + + self.enable_eplb = parallel_config.enable_eplb + self.n_routed_experts = config.moe_num_experts + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + if self.tp_size > config.moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.moe_num_experts}.") + + assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1" + + self.use_moe_router_bias = config.use_moe_router_bias + if self.use_moe_router_bias: + self.router_bias = 
nn.Parameter(torch.zeros(config.moe_num_experts, + dtype=torch.float32), + requires_grad=False) + custom_routing_function = self.router_bias_func + elif config.moe_router_activation == "sigmoid": + custom_routing_function = sigmoid_routing_function + else: + custom_routing_function = None + self.need_fp32_gate = config.need_fp32_gate + layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + activation = "silu" + swigluoai_step_limit = None + if config.swiglu_limits and config.swiglu_limits[ + layer_idx] is not None and config.swiglu_limits[layer_idx] != 0: + swigluoai_step_limit = config.swiglu_limits[layer_idx] + activation = "swigluoai-step" + logger.info( + f"step3p5 layer_idx: {layer_idx}, activation limit: {config.swiglu_limits[layer_idx]}, will use swigluoai-step" + ) + moe_intermediate_size = _pad_size_for_groupwise_quant( + config.moe_intermediate_size, + quant_config, + ) + self.experts = SharedFusedMoE( + shared_experts=shared_experts, + num_experts=config.moe_num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=moe_intermediate_size, + reduce_results=reduce_results, + renormalize=config.norm_expert_weight, + quant_config=quant_config, + activation=activation, + activation_limit=swigluoai_step_limit if swigluoai_step_limit else None, + prefix=f"{prefix}.experts", + custom_routing_function=custom_routing_function, + routed_scaling_factor=config.moe_router_scaling_factor, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + self.gate = ReplicatedLinear(config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + def router_bias_func(self, hidden_states: torch.Tensor, + gating_output: torch.Tensor, topk: int, + renormalize: bool): + gate_prob = torch.sigmoid(gating_output.float()) + gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0) + _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1) + topk_prob = 
torch.gather(gate_prob, 1, indices) + expert_topk_weight = topk_prob + if renormalize: + expert_topk_weight = expert_topk_weight / ( + torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20) + return expert_topk_weight, indices.to(torch.int32) + + def forward( + self, + hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + if self.need_fp32_gate: + router_logits = hidden_states.to( + torch.float32) @ self.gate.weight.to(torch.float32).t() + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + shared_out, final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits) + + return shared_out, final_hidden_states.view(orig_shape) + + +class Step3p5DecoderLayer(nn.Module): + + def __init__(self, + config: ModelConfig, + parallel_config: ParallelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + use_fused_moe: bool = False, + prefix: str = "") -> None: + super().__init__() + config = config.hf_config + self.hidden_size = config.hidden_size + rope_scaling = getattr(config, "rope_scaling", None) + layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + self.layer_idx = layer_idx + if cache_config is not None: + cache_config.sliding_window = None + if config.att_impl_type == "GQA": + num_attention_heads = None + num_attention_groups = None + head_dim = None + if getattr(config, "attention_other_setting", None) and getattr( + config, "layer_types", []) and config.layer_types[ + layer_idx] == config.attention_other_setting[ + 'attention_type']: + num_attention_heads = config.attention_other_setting[ + 'num_attention_heads'] + num_attention_groups = config.attention_other_setting[ + 'num_attention_groups'] + head_dim = config.attention_other_setting['head_dim'] + partial_rotary_factors = 
getattr(config, "partial_rotary_factors", + []) + self.self_attn = Step3p5Attention( + hidden_size=self.hidden_size, + num_heads=num_attention_heads + if num_attention_heads else config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=num_attention_groups + if num_attention_groups else config.num_attention_groups, + rope_theta=config.rope_theta, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=head_dim if head_dim else getattr( + config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + sliding_window=getattr(config, 'sliding_window', None), + enable_sink=getattr(config, "sink", False), + use_head_wise_attn_gate=getattr(config, + "use_head_wise_attn_gate", + False), + layer_types=getattr(config, "layer_types", []), + use_rope_layers=getattr(config, "use_rope_layers", []), + yarn_only_types=getattr(config, "yarn_only_types", []), + partial_rotary_factor=partial_rotary_factors[layer_idx] + if partial_rotary_factors else 1.0, + zero_centered=getattr(config, "zero_centered", False), + prefix=f"{prefix}.self_attn", + ) + else: + raise ValueError( + f"Unsupported attention implementation: {config.att_impl_type}" + ) + self.use_moe = False + self.tp_group = get_tp_group() + self.use_fused_all_reduce = get_tensor_model_parallel_world_size( + ) > 1 and get_dp_group().world_size == 1 and envs.VLLM_USE_FUSED_ALL_REDUCE + if self.use_fused_all_reduce: + logger.warning_once("Enable custom fused all reduce...") + else: + logger.warning_once("Disable custom fused all reduce...") + + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + moe_layers_idx = [ + int(i) for i in moe_layers_enum.strip().split(',') + ] + else: + # Default to 1dense. 
+ moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] + if layer_idx in moe_layers_idx: + reduce_results = True + if self.use_fused_all_reduce or self.tp_group.world_size == 1 and get_ep_group( + ).world_size == 1: + reduce_results = False + moe_intermediate_size = _pad_size_for_groupwise_quant( + config.share_expert_dim, quant_config) + self.share_expert = Step3p5MLP( + config=config, + hidden_size=self.hidden_size, + intermediate_size=moe_intermediate_size, + hidden_act="silu", + reduce_results=reduce_results, + quant_config=quant_config, + prefix=f"{prefix}.share_expert") + self.moe = FusedMoEBlock(shared_experts=self.share_expert, + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.moe") + self.use_moe = True + else: + self.mlp = Step3p5MLP(config=config, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.mlp") + self.use_fused_moe = use_fused_moe + self.input_layernorm = Step3p5RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + zero_centered=config.zero_centered) + self.post_attention_layernorm = Step3p5RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + zero_centered=config.zero_centered) + self.prefix = prefix + + def add_and_maybe_inplace_all_reduce(self, in1: torch.Tensor, + in2: torch.Tensor) -> torch.Tensor: + if not self.use_fused_all_reduce: + return in1 + in2 + return self.tp_group.all_reduce(in1 + in2) + + def forward(self, positions: torch.Tensor, + hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states += residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.use_moe: + shared_output, 
moe_output = self.moe(hidden_states) + # share expert & moe 可以合并all reduce + ffn_output = self.add_and_maybe_inplace_all_reduce( + moe_output, shared_output) + else: + ffn_output = self.mlp(hidden_states) + hidden_states = ffn_output + residual + return hidden_states + + +# Note: max-num-batched-tokens 开到64k 编译会不通过,小于64k没啥问题 +@support_torch_compile +class Step3p5Model(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + prefix: str = "", + use_fused_moe: bool = False) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = config.vocab_size + self.config = config + self.use_fused_moe = use_fused_moe + + self.moe_num_experts = config.moe_num_experts + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Step3p5DecoderLayer( + config=vllm_config.model_config, + parallel_config=vllm_config.parallel_config, + cache_config=cache_config, + quant_config=quant_config, + use_fused_moe=self.use_fused_moe, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = Step3p5RMSNorm(config.hidden_size, + eps=config.rms_norm_eps, + zero_centered=config.zero_centered) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = 
None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + + return hidden_states + + +class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + self.config = config + self.vllm_config = vllm_config + + self.model = Step3p5Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + use_fused_moe=True) + + self.moe_layers: list[FusedMoEBlock] = [] + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + assert isinstance(layer, Step3p5DecoderLayer) + if hasattr(layer, "moe") and isinstance(layer.moe, FusedMoEBlock): + self.moe_layers.append(layer.moe) + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # Set MoE hyperparameters + self.expert_weights = [] + assert len(self.moe_layers) > 0, "No MoE layers 
found in the model." + example_layer = self.moe_layers[0] + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_layer.n_logical_experts + self.num_physical_experts = example_layer.n_physical_experts + self.num_local_physical_experts = example_layer.n_local_physical_experts + self.num_routed_experts = example_layer.n_routed_experts + self.num_redundant_experts = example_layer.n_redundant_experts + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None): + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.model.norm(hidden_states) + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_tokens(input_ids) + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # layer_idx = layer.layer_idx + experts = layer.experts + assert isinstance(experts, FusedMoE) + # Register the expert weights. 
+ self.expert_weights.append(experts.get_expert_weights()) + experts.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.moe_layers: + assert isinstance(layer, FusedMoEBlock) + layer.n_local_physical_experts = num_local_physical_experts + layer.n_physical_experts = num_physical_experts + layer.n_redundant_experts = self.num_redundant_experts + layer.experts.update_expert_map() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + vllm_config = self.vllm_config + config = vllm_config.model_config.hf_config + assert config.num_attention_groups > 1, "Only support GQA" + #GQA + qkv_params_mapping = [] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params = set() + + if self.model.use_fused_moe: + is_groupwise_quant = self.vllm_config.quant_config is not None and self.vllm_config.quant_config.get_name( + ) == "groupwise_quant" + if is_groupwise_quant: + expert_params_mapping = [ + (".moe.experts.w13_weight", ".moe.gate_proj.qweight", + "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"), + (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales", + "w1"), + (".moe.experts.w13_weight_scale", 
".moe.up_proj.scales", + "w3"), + (".moe.experts.w2_weight_scale", ".moe.down_proj.scales", + "w2"), + ] + else: + expert_params_mapping = [ + (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + ] + else: + expert_params_mapping = [] + + disable_moe_stacked_params = [data[1] for data in expert_params_mapping] + + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if any(disable_moe_stacked_param in name + for disable_moe_stacked_param in + disable_moe_stacked_params): + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + loaded_weight = pad_param(loaded_weight, name, param, + self.vllm_config.quant_config) + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. 
+ if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + moe_expert_num = self.model.moe_num_experts + assert loaded_weight.shape[0] == moe_expert_num + for expert_id in range(moe_expert_num): + loaded_weight_expert = loaded_weight[expert_id] + loaded_weight_expert = pad_param( + loaded_weight_expert, name, param, + self.vllm_config.quant_config) + weight_loader(param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id) + loaded_params.add(name) + break + else: + for (param_name, weight_name, start_idx, + end_idx) in qkv_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + dim = param.shape[param.output_dim] + begin_idx = int(start_idx * dim) + end_idx = int(end_idx * dim) + param_slice = param.narrow(param.output_dim, begin_idx, + end_idx - begin_idx) + param_slice.copy_(loaded_weight) + loaded_params.add(name) + break + else: + if is_pp_missing_parameter(name, self): + continue + if "expert_bias" in name: + logger.warning_once("ignore expert_bias") + continue + param = params_dict[name] + loaded_weight = pad_param( + loaded_weight, name, param, + self.vllm_config.quant_config) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def get_spec_layer_idx_from_weight_name(config: ModelConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/step3p5_mtp.py 
b/vllm/model_executor/models/step3p5_mtp.py new file mode 100644 index 000000000000..eb3da14f2b7d --- /dev/null +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.logger import init_logger +# from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors + +from .step3p5 import Step3p5RMSNorm, Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name +from .utils import maybe_prefix + +logger = init_logger(__name__) + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = Step3p5RMSNorm(config.hidden_size, + eps=config.rms_norm_eps, + zero_centered=config.zero_centered) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class Step3p5AMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + parallel_config: ParallelConfig = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.enorm = 
Step3p5RMSNorm(config.hidden_size, + eps=config.rms_norm_eps, + zero_centered=config.zero_centered) + self.hnorm = Step3p5RMSNorm(config.hidden_size, + eps=config.rms_norm_eps, + zero_centered=config.zero_centered) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = Step3p5DecoderLayer(model_config, + parallel_config=parallel_config, + cache_config=cache_config, + quant_config=quant_config, + use_fused_moe=True, + prefix=f"{prefix}.mtp_block") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states = self.mtp_block(positions=positions, + hidden_states=hidden_states) + return hidden_states + + +class Step3p5AMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + Step3p5AMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + parallel_config=vllm_config.parallel_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in 
class Step3p5MTP(nn.Module):
    """Multi-token-prediction (MTP) wrapper for Step3p5 speculative decoding.

    Thin adapter around ``Step3p5AMultiTokenPredictor`` exposing the
    ``forward`` / ``compute_logits`` / ``load_weights`` interface expected
    by vLLM's speculative-decoding worker.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model = Step3p5AMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up input embeddings via the underlying predictor."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step.

        ``intermediate_tensors`` is accepted for interface compatibility
        with other vLLM models but is not used here.
        """
        return self.model(input_ids, positions, hidden_states, inputs_embeds,
                          spec_step_idx)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits for the given step."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the MTP module.

        Handles three checkpoint naming schemes: stacked projections
        (qkv / gate_up), fused-MoE expert weights, and plain per-parameter
        names. After loading, verifies that every non-optional parameter
        received a value.

        Returns:
            The set of fully-qualified parameter names that were loaded.

        Raises:
            RuntimeError: if a non-optional parameter is missing from the
                checkpoint (it would otherwise silently keep its random
                initialization).
        """
        vllm_config = self.vllm_config
        config = vllm_config.model_config.hf_config

        if config.num_attention_groups > 1:
            # GQA: q/k/v are stored separately in the checkpoint but fused
            # into a single qkv_proj parameter in vLLM.
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
        elif config.att_impl_type == "MLA":
            # MLA attention keeps its projections unstacked; only the MLP
            # gate/up projections are fused.
            stacked_params_mapping = [
                ("gate_up_proj", "gate_proj", 0),
                ("gate_up_proj", "up_proj", 1),
            ]
        else:
            stacked_params_mapping = [
                (".gate_up_proj", ".gate_proj", 0),
                (".gate_up_proj", ".up_proj", 1),
            ]

        if self.model.use_fused_moe:
            if (self.vllm_config.quant_config is not None
                    and self.vllm_config.quant_config.get_name()
                    == "groupwise_quant"):
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.qweight",
                     "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.qweight",
                     "w2"),
                    (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales",
                     "w1"),
                    (".moe.experts.w13_weight_scale", ".moe.up_proj.scales",
                     "w3"),
                    (".moe.experts.w2_weight_scale", ".moe.down_proj.scales",
                     "w2"),
                ]
            else:
                expert_params_mapping = [
                    (".moe.experts.w13_weight", ".moe.gate_proj.weight",
                     "w1"),
                    (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                    (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
                ]
        else:
            expert_params_mapping = []

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if "embed_tokens" not in name and spec_layer is None:
                # Only MTP-layer weights (plus the shared embedding) belong
                # to this module.
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Expert weights are handled by expert_params_mapping below;
                # skip them BEFORE rewriting the name, otherwise e.g.
                # mlp.experts[0].gate_proj would become
                # mlp.experts[0].gate_up_proj and break expert loading.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if "experts" in name or "moe" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                # BUGFIX: stacked params were never recorded as loaded,
                # which made the completeness check at the end of this
                # method spuriously raise for them.
                loaded_params.add(name)
                break
            else:
                for param_name, weight_name, shard_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The checkpoint stores all experts stacked on dim 0;
                    # load them one expert at a time.
                    for expert_id in range(loaded_weight.shape[0]):
                        weight_loader(param,
                                      loaded_weight[expert_id],
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    # Skip loading extra bias for GPTQ models, and skip the
                    # duplicated tok_embeddings alias.
                    if (name.endswith(".bias") and name not in params_dict
                            or "tok_embeddings" in name):
                        continue
                    if f"{config.num_hidden_layers}.transformer." in name:
                        name = name.replace(".transformer.", ".")
                    if "shared_head" in name:
                        name = name.replace("shared_head.output",
                                            "shared_head.head")
                    if "embed_tokens" in name:
                        # The MTP embedding is shared with the base model.
                        assert hasattr(
                            self.config, "num_nextn_predict_layers"
                        ) and self.config.num_nextn_predict_layers > 0
                        name = "model.embed_tokens.weight"
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

        params_need_to_load = set(params_dict.keys())
        # Some KV cache scales are optional: checkpoints may omit them and
        # vLLM will fall back to default scales during initialization.
        optional_params = {
            name
            for name, param in params_dict.items()
            if name.endswith(
                (".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
            and getattr(param, "numel", lambda: 0)() == 1
            and getattr(param, "requires_grad", False) is False
        }
        params_need_to_load -= optional_params
        if params_need_to_load != loaded_params:
            missing_params = list(params_need_to_load - loaded_params)
            param_name_example = missing_params[0]
            raise RuntimeError(
                f"Some parameters like {param_name_example} are not in the "
                "checkpoint and will falsely use random initialization")
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """Rewrite a checkpoint weight name to match this module's layout.

        Weights belonging to the transformer block of a speculative layer
        get an extra ``.mtp_block`` path component; layer-level auxiliary
        weights (embeddings, norms, eh_proj, shared head) are left as-is.
        """
        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        is_aux_weight = any(
            weight_name in name for weight_name in spec_layer_weight_names)
        if not is_aux_weight:
            # Treat remaining weights as transformer-block weights.
            name = name.replace(f"model.layers.{spec_layer}.",
                                f"model.layers.{spec_layer}.mtp_block.")
        return name
class Step3p5ReasoningParser(BaseThinkingReasoningParser):
    """Reasoning parser for the Step3p5 model.

    Step3p5 uses the ``<think>...</think>`` format, but it tends to emit an
    extra newline immediately before and/or after the ``</think>`` token.
    This parser trims:
      - the newline right before ``</think>``
      - the newline right after ``</think>``

    NOTE(review): the start/end token literals were lost to markup
    stripping in the reviewed text; restored to the standard
    ``<think>`` / ``</think>`` pair used by the Step3 family.
    """

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # Holds a trailing "\n" from reasoning content so we can decide
        # whether it sits immediately before </think> (drop it) or inside
        # the reasoning (emit it on the next delta).
        self._pending_reasoning_newline = False

    def extract_reasoning(
        self,
        model_output: str,
        request: ChatCompletionRequest | ResponsesRequest,
    ) -> tuple[str | None, str | None]:
        """Non-streaming extraction: delegate to the base parser, then trim
        the single newline adjacent to the </think> boundary."""
        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = reasoning.removesuffix("\n")
        if content is not None:
            content = content.removeprefix("\n")
        return reasoning or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Streaming extraction with boundary-newline trimming.

        Buffers a trailing reasoning "\\n" until the next delta reveals
        whether it directly precedes </think> (dropped) or not (flushed).
        """
        # Drop the immediate newline that models often emit after </think>.
        if previous_text.endswith(self.end_token) and delta_text:
            if delta_text == "\n":
                return None
            elif delta_text.startswith("\n"):
                remaining = delta_text.removeprefix("\n")
                return DeltaMessage(content=remaining) if remaining else None

        # If we are about to see the end token, any pending newline is
        # immediately before </think> and should be dropped.
        if self.end_token_id in delta_token_ids and self._pending_reasoning_newline:
            self._pending_reasoning_newline = False

        ret = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )

        if ret is None:
            return None

        # Compatibility path for models that don't generate the start token:
        # treat everything before </think> as reasoning and everything after
        # as content.
        if (
            self.start_token_id not in previous_token_ids
            and self.start_token_id not in delta_token_ids
        ):
            if self.end_token_id in delta_token_ids:
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token):]
                ret = DeltaMessage(reasoning=reasoning, content=content or None)
            elif self.end_token_id in previous_token_ids:
                ret = DeltaMessage(content=delta_text or None)
            else:
                ret = DeltaMessage(reasoning=delta_text or None)

        reasoning_to_output = ret.reasoning
        content_to_output = ret.content

        # Reasoning: handle the newline immediately before </think>.
        if reasoning_to_output is not None:
            if self._pending_reasoning_newline:
                # Previous delta's held newline was NOT at the boundary;
                # flush it back into the stream.
                reasoning_to_output = "\n" + reasoning_to_output
                self._pending_reasoning_newline = False

            if reasoning_to_output.endswith("\n"):
                reasoning_to_output = reasoning_to_output.removesuffix("\n")
                if self.end_token in delta_text:
                    # Trailing "\n" is right before </think>, drop it.
                    self._pending_reasoning_newline = False
                else:
                    # Hold the trailing "\n" until we know whether
                    # </think> follows.
                    self._pending_reasoning_newline = True

        # Content: handle the newline immediately after </think>.
        if content_to_output is not None:
            # If we have content, reasoning must have ended.
            self._pending_reasoning_newline = False

            if self.end_token in delta_text and content_to_output.startswith("\n"):
                content_to_output = content_to_output.removeprefix("\n")

        reasoning_to_output = reasoning_to_output or None
        content_to_output = content_to_output or None
        if reasoning_to_output is None and content_to_output is None:
            return None

        return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output)
class Qwen3CoderToolParserRL(ToolParser):
    """Tool parser for the Qwen3-Coder XML tool-call format (RL variant).

    Parses model output of the form::

        <tool_call><function=NAME>
        <parameter=ARG>
        VALUE
        </parameter>
        </function></tool_call>

    Streaming strategy: accumulate text, and whenever a *complete* tool call
    is available, parse it with the non-streaming path and replay it as a
    sequence of deltas. On EOS, incomplete tool calls are flushed as plain
    content.

    NOTE(review): the sentinel token literals and regexes below were garbled
    by markup stripping in the reviewed text; restored to the Qwen3-Coder
    XML format — confirm against the upstream Qwen3CoderToolParser.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        # Override base class type - we use string IDs for tool calls.
        self.current_tool_id: Optional[str] = None  # type: ignore
        self.streamed_args_for_tool: list[str] = []

        # Sentinel tokens for streaming mode.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")

        # EOS token id, used to detect end-of-stream in streaming mode.
        self.eos_token_id = getattr(self.model_tokenizer, 'eos_token_id', None)

        logger.info("vLLM Successfully import tool parser %s !",
                    self.__class__.__name__)

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _reset_streaming_state(self):
        """Reset all streaming state for a new request."""
        self._processed_length: int = 0  # Position of last processed char
        self._tool_call_index: int = 0  # Number of tool calls processed
        self.streaming_request = None  # Current request being processed

    def _get_arguments_config(
            self, func_name: str,
            tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
        """Extract the JSON-schema properties for ``func_name``, or {}."""
        if tools is None:
            return {}
        for config in tools:
            if not hasattr(config, "type") or not (hasattr(
                    config, "function") and hasattr(config.function, "name")):
                continue
            if config.type == "function" and config.function.name == func_name:
                if not hasattr(config.function, "parameters"):
                    return {}
                params = config.function.parameters
                if isinstance(params, dict) and "properties" in params:
                    return params["properties"]
                elif isinstance(params, dict):
                    return params
                else:
                    return {}
        logger.warning("Tool '%s' is not defined in the tools list.",
                       func_name)
        return {}

    def _convert_param_value(self, param_value: str, param_name: str,
                             param_config: dict, func_name: str) -> Any:
        """Convert a raw string parameter to the type declared in the schema.

        Raises:
            ValueError: if the value cannot be coerced to the declared type.
        """
        # Handle null value for any type.
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config != {}:
                logger.warning(
                    "Parsed parameter '%s' is not defined in the tool "
                    "parameters for tool '%s', directly returning the "
                    "string value.", param_name, func_name)
            return param_value

        if isinstance(param_config[param_name],
                      dict) and "type" in param_config[param_name]:
            param_type = str(param_config[param_name]["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value
        elif (param_type.startswith("int") or param_type.startswith("uint")
              or param_type.startswith("long")
              or param_type.startswith("short")
              or param_type.startswith("unsigned")):
            try:
                return int(param_value)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Parsed value '{param_value}' of parameter '{param_name}' "
                    f"is not an integer in tool '{func_name}'.") from e
        elif param_type.startswith("num") or param_type.startswith("float"):
            try:
                float_param_value = float(param_value)
                # Collapse integral floats (e.g. "3.0") to int.
                return float_param_value if float_param_value - int(
                    float_param_value) != 0 else int(float_param_value)
            except (ValueError, TypeError) as e:
                raise ValueError(
                    f"Parsed value '{param_value}' of parameter '{param_name}' "
                    f"is not a float in tool '{func_name}'.") from e
        elif param_type in ["boolean", "bool", "binary"]:
            param_value = param_value.lower()
            if param_value not in ["true", "false"]:
                raise ValueError(
                    f"Parsed value '{param_value}' of parameter '{param_name}' "
                    f"is not a boolean (`true` or `false`) in tool '{func_name}'."
                )
            return param_value == "true"
        else:
            if param_type in ["object", "array", "arr"
                              ] or param_type.startswith(
                                  "dict") or param_type.startswith("list"):
                try:
                    param_value = json.loads(param_value)
                    return param_value
                except (json.JSONDecodeError, TypeError, ValueError) as e:
                    raise ValueError(
                        f"Parsed value '{param_value}' of parameter '{param_name}' "
                        f"cannot be parsed with json.loads in tool '{func_name}'."
                    ) from e
            try:
                # Fall back to Python literal syntax; safer than eval.
                param_value = ast.literal_eval(param_value)
            except (ValueError, SyntaxError, TypeError) as e:
                raise ValueError(
                    f"Parsed value '{param_value}' of parameter '{param_name}' "
                    f"cannot be converted via Python `ast.literal_eval()` in tool "
                    f"'{func_name}'.") from e
            return param_value

    def _parse_xml_function_call(
            self, function_call_str: str,
            tools: Optional[list[ChatCompletionToolsParam]]
    ) -> Optional[ToolCall]:
        """Parse one ``NAME>...<parameter=...>`` body into a ToolCall.

        Returns None when any parameter fails type conversion.
        """
        # Extract function name (text up to the closing '>').
        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = self._get_arguments_config(function_name, tools)
        parameters = function_call_str[end_index + 1:]
        param_dict = {}
        for match_result in self.tool_call_parameter_regex.findall(parameters):
            # The alternation yields a tuple; take whichever group matched.
            match_text = match_result[0] if match_result[0] else match_result[1]
            idx = match_text.index(">")
            param_name = match_text[:idx]
            param_value = str(match_text[idx + 1:])
            # Remove the single framing newline the template emits.
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]

            try:
                param_dict[param_name] = self._convert_param_value(
                    param_value, param_name, param_config, function_name)
            except Exception:
                return None
        return ToolCall(
            type="function",
            function=FunctionCall(name=function_name,
                                  arguments=json.dumps(param_dict,
                                                       ensure_ascii=False)),
        )

    def _get_function_calls(self, model_output: str) -> list[str]:
        """Return the raw <function=...> bodies of all closed tool calls."""
        raw_tool_calls = self.tool_call_complete_regex.findall(model_output)

        # If no closed tool_call tags found, return empty list.
        if len(raw_tool_calls) == 0:
            return []

        raw_function_calls = []
        for tool_call in raw_tool_calls:
            function_matches = self.tool_call_function_regex.findall(tool_call)
            # Each findall result is a tuple from the alternation.
            raw_function_calls.extend(
                m[0] if m[0] else m[1] for m in function_matches)

        return raw_function_calls

    def _check_format(self, model_output: str) -> bool:
        """Check if model output contains a properly formatted tool call.

        Requirements:
        1. Must have closed tool_call tags (<tool_call>...</tool_call>)
        2. Must have closed function tags (</function>)
        3. If parameter tags exist, they must be closed and correct

        Returns True if the format is valid, False otherwise.
        """
        # Check 1: Must have closed tool_call tags.
        tool_call_matches = self.tool_call_complete_regex.findall(model_output)
        if len(tool_call_matches) == 0:
            return False

        # Check 2: Must have closed function tags within tool_call.
        has_valid_function = False
        for tool_call_content in tool_call_matches:
            function_matches = self.tool_call_function_regex.findall(
                tool_call_content)
            if len(function_matches) > 0:
                has_valid_function = True
            # Check if there's an unclosed function tag.
            if (self.tool_call_prefix in tool_call_content
                    and self.function_end_token not in tool_call_content):
                return False

        if not has_valid_function:
            return False

        # Check 3: If parameter tags exist, they must be closed and correct.
        for tool_call_content in tool_call_matches:
            # Count opening and closing parameter tags.
            param_open_count = tool_call_content.count(self.parameter_prefix)
            param_close_count = tool_call_content.count(
                self.parameter_end_token)

            # If there are parameter tags, they must be balanced.
            if param_open_count > 0:
                if param_open_count != param_close_count:
                    return False
                # Check all parameter tags are properly closed using regex.
                param_matches = self.tool_call_parameter_regex.findall(
                    tool_call_content)
                if len(param_matches) != param_open_count:
                    return False

        return True

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Non-streaming extraction of all tool calls from a full response."""
        # Quick check to avoid unnecessary processing.
        if not self._check_format(model_output):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            function_calls = self._get_function_calls(model_output)
            if len(function_calls) == 0:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            tool_calls = [
                self._parse_xml_function_call(function_call_str, request.tools)
                for function_call_str in function_calls
            ]

            # Populate prev_tool_call_arr so the serving layer can set
            # finish_reason correctly.
            self.prev_tool_call_arr.clear()
            for tool_call in tool_calls:
                if tool_call:
                    self.prev_tool_call_arr.append({
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    })

            # Extract content before the first tool call.
            content_index = model_output.find(self.tool_call_start_token)
            content = model_output[:content_index]

            return ExtractedToolCallInformation(
                tools_called=(len(tool_calls) > 0),
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.warning("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def _find_first_complete_tool_call_end(self,
                                           text: str,
                                           start_pos: int = 0) -> int:
        """Find the end position of the first complete tool call.

        Args:
            text: Text to search in
            start_pos: Position to start searching from

        Returns:
            Position after the first </tool_call> tag, or -1 if incomplete
        """
        # Find tool call start.
        start_idx = text.find(self.tool_call_start_token, start_pos)
        if start_idx == -1:
            return -1

        # Find matching end token.
        end_idx = text.find(self.tool_call_end_token,
                            start_idx + len(self.tool_call_start_token))
        if end_idx == -1:
            return -1  # Incomplete tool call

        # Return position after end token.
        return end_idx + len(self.tool_call_end_token)

    def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int:
        """Find the start position of the next tool call.

        Returns:
            Position of the <tool_call> token, or -1 if not found
        """
        return text.find(self.tool_call_start_token, start_pos)

    def _extract_content_between_tool_calls_list(self, text: str) -> list[str]:
        """Extract content segments after each tool call.

        For n tool calls, returns n segments where segment[i] is the content
        after tool_call[i] (before tool_call[i+1] or at the end).

        Empty or whitespace-only segments are represented as "".
        """
        content_segments = []
        pos = 0

        while True:
            # Find end of current tool call.
            end_pos = text.find(self.tool_call_end_token, pos)
            if end_pos == -1:
                break

            # Move past the end token.
            end_pos += len(self.tool_call_end_token)

            # Find start of next tool call.
            next_start = self._find_tool_call_start(text, end_pos)

            # Content between current end and next start (or text end).
            content = text[end_pos:next_start] if next_start != -1 else text[
                end_pos:]

            # Store content (empty string if whitespace-only).
            content_segments.append(content if content.strip() else "")

            if next_start == -1:
                break
            pos = next_start

        return content_segments

    def _convert_tool_calls_to_deltas(
            self,
            tool_calls: list[ToolCall],
            starting_index: int = 0) -> list[DeltaMessage]:
        """Convert complete ToolCall objects to a delta message sequence.

        Format: header (function name) -> { -> param1 -> param2 -> ... -> }
        """
        deltas = []
        for i, tool_call in enumerate(tool_calls):
            index = starting_index + i
            tool_id = self._generate_tool_call_id()

            # Header delta: function name.
            deltas.append(
                DeltaMessage(tool_calls=[
                    DeltaToolCall(
                        index=index,
                        id=tool_id,
                        function=DeltaFunctionCall(
                            name=tool_call.function.name, arguments=""),
                        type="function",
                    )
                ]))

            # Opening brace.
            deltas.append(
                DeltaMessage(tool_calls=[
                    DeltaToolCall(index=index,
                                  function=DeltaFunctionCall(arguments="{"))
                ]))

            # Parse arguments JSON and stream each parameter as a fragment.
            try:
                args_dict = json.loads(tool_call.function.arguments)
                param_names = list(args_dict.keys())

                for param_idx, param_name in enumerate(param_names):
                    param_value = args_dict[param_name]
                    serialized_value = json.dumps(param_value,
                                                  ensure_ascii=False)

                    if param_idx == 0:
                        json_fragment = f'"{param_name}": {serialized_value}'
                    else:
                        json_fragment = f', "{param_name}": {serialized_value}'

                    deltas.append(
                        DeltaMessage(tool_calls=[
                            DeltaToolCall(index=index,
                                          function=DeltaFunctionCall(
                                              arguments=json_fragment))
                        ]))
            except (json.JSONDecodeError, KeyError):
                # If parsing fails, send the arguments as-is.
                if tool_call.function.arguments:
                    deltas.append(
                        DeltaMessage(tool_calls=[
                            DeltaToolCall(
                                index=index,
                                function=DeltaFunctionCall(
                                    arguments=tool_call.function.arguments))
                        ]))

            # Closing brace.
            deltas.append(
                DeltaMessage(tool_calls=[
                    DeltaToolCall(index=index,
                                  function=DeltaFunctionCall(arguments="}"))
                ]))

        return deltas

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, list[DeltaMessage], None]:
        """Extract tool calls from streaming text using complete parsing.

        Strategy:
        1. Accumulate text and track the processed position
        2. Each iteration, extract content or complete tool calls
        3. Parse complete tool calls using the non-streaming method
        4. Convert parsed results to a delta sequence
        5. Handle EOS to flush incomplete tool calls as content
        """
        # Initialize state for a new request.
        if not previous_text:
            self._reset_streaming_state()
            self.streaming_request = request

        # Check for EOS token.
        has_eos = (self.eos_token_id is not None and delta_token_ids
                   and self.eos_token_id in delta_token_ids)

        # NOTE: This simple check may incorrectly detect EOS when model
        # output contains <|im_end|> tokens (e.g. multi-turn conversations).
        # If needed, re-encode delta_text and compare token counts to tell a
        # genuine terminator apart from an in-band <|im_end|>.

        # No delta text: possibly an empty delta needed for finish_reason.
        if not delta_text and not has_eos:
            if (delta_token_ids
                    and self.tool_call_end_token_id not in delta_token_ids):
                # Count complete tool calls.
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text))

                # If tool calls completed and prev_tool_call_arr populated...
                if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                    # ...and all tool calls are closed, emit an empty delta
                    # so the serving layer can set finish_reason.
                    open_calls = current_text.count(
                        self.tool_call_start_token) - current_text.count(
                            self.tool_call_end_token)
                    if open_calls == 0:
                        return DeltaMessage(content="")
            return None

        # Process all available content.
        accumulated_deltas: list[DeltaMessage] = []

        while self._has_unprocessed_content(current_text):
            # Try to process the next chunk (content or tool call).
            delta = self._process_next_chunk(current_text)

            if delta is None:
                # Cannot proceed further, need more tokens.
                break

            if isinstance(delta, list):
                accumulated_deltas.extend(delta)
            else:
                accumulated_deltas.append(delta)

        # Handle EOS: flush remaining incomplete tool calls as content.
        if has_eos:
            remaining_delta = self._flush_remaining_content(current_text)
            if remaining_delta:
                accumulated_deltas.append(remaining_delta)
            # No remaining content but we have tool calls: empty delta.
            elif len(self.prev_tool_call_arr) > 0:
                open_calls = current_text.count(
                    self.tool_call_start_token) - current_text.count(
                        self.tool_call_end_token)
                if open_calls == 0:
                    accumulated_deltas.append(DeltaMessage(content=""))

        return self._format_delta_result(accumulated_deltas)

    def _has_unprocessed_content(self, current_text: str) -> bool:
        """Check if there's unprocessed content in the buffer."""
        return self._processed_length < len(current_text)

    def _process_next_chunk(
        self, current_text: str
    ) -> Union[DeltaMessage, list[DeltaMessage], None]:
        """Process next chunk: regular content or a complete tool call.

        Returns:
            - DeltaMessage(s) if processed successfully
            - None if we cannot proceed yet (need more tokens)
        """
        # Find next tool call start.
        tool_start_idx = self._find_tool_call_start(current_text,
                                                    self._processed_length)

        # Case 1: No tool call found - return remaining content.
        if tool_start_idx == -1:
            return self._process_content(current_text, self._processed_length,
                                         len(current_text))

        # Case 2: Content before tool call.
        if tool_start_idx > self._processed_length:
            return self._process_content(current_text, self._processed_length,
                                         tool_start_idx)

        # Case 3: Tool call at current position.
        tool_end_idx = self._find_first_complete_tool_call_end(
            current_text, tool_start_idx)

        if tool_end_idx == -1:
            # Tool call incomplete, wait for more tokens.
            return None

        # Process complete tool call.
        return self._process_complete_tool_calls(current_text, tool_start_idx,
                                                 tool_end_idx)

    def _process_content(self, current_text: str, start_pos: int,
                         end_pos: int) -> Union[DeltaMessage, None]:
        """Process regular (non-tool-call) content in [start_pos, end_pos)."""
        if start_pos >= end_pos:
            return None

        content = current_text[start_pos:end_pos]

        # If we just ended a tool call and this span is pure whitespace,
        # swallow it (inter-tool-call whitespace is not content).
        if start_pos > 0:
            text_before = current_text[:start_pos]
            if (text_before.rstrip().endswith(self.tool_call_end_token)
                    and content.strip() == ""):
                self._processed_length = end_pos
                return None

        # Return content if non-empty.
        if content:
            self._processed_length = end_pos
            return DeltaMessage(content=content)

        # Mark as processed even if empty.
        self._processed_length = end_pos
        return None

    def _flush_remaining_content(
            self, current_text: str) -> Union[DeltaMessage, None]:
        """Flush any remaining unprocessed text as plain content.

        Used when the EOS token is seen, to surface incomplete tool calls.
        """
        if not self._has_unprocessed_content(current_text):
            return None

        remaining = current_text[self._processed_length:]
        self._processed_length = len(current_text)
        if remaining:
            return DeltaMessage(content=remaining)
        return None

    def _format_delta_result(
        self, deltas: list[DeltaMessage]
    ) -> Union[DeltaMessage, list[DeltaMessage], None]:
        """Normalize a delta list: None if empty, bare message if single."""
        if not deltas:
            return None
        elif len(deltas) == 1:
            return deltas[0]
        else:
            return deltas

    def _process_complete_tool_calls(
            self, current_text: str, start_pos: int,
            end_pos: int) -> Union[list[DeltaMessage], None]:
        """Parse a complete tool-call span and convert it to deltas.

        Args:
            current_text: Current accumulated text
            start_pos: Start position (at <tool_call>)
            end_pos: End position (after </tool_call>)
        """
        try:
            # Extract the span containing the complete tool call(s).
            text_to_parse = current_text[start_pos:end_pos]

            # Parse using the non-streaming method.
            result = self.extract_tool_calls(text_to_parse,
                                             self.streaming_request)

            # Case 1: Successfully parsed tool calls.
            if result.tools_called and result.tool_calls:
                # Due to _find_first_complete_tool_call_end we typically see
                # one tool call at a time, but multiple are handled too.
                deltas = self._build_tool_call_deltas(result.tool_calls,
                                                      text_to_parse)
                self._update_state_after_tool_calls(result.tool_calls, end_pos)
                return deltas if deltas else None

            # Case 2: Parsing failed - treat as regular content.
            self._processed_length = end_pos
            return [DeltaMessage(content=text_to_parse)]

        except Exception as e:
            # Exception during parsing - treat as content.
            logger.debug(
                f"Failed to parse tool calls: {e}, treating as content")
            self._processed_length = end_pos
            failed_text = current_text[start_pos:end_pos]
            return [DeltaMessage(content=failed_text)] if failed_text else None

    def _build_tool_call_deltas(self, tool_calls: list[ToolCall],
                                parsed_text: str) -> list[DeltaMessage]:
        """Build deltas from parsed tool calls with interleaved content."""
        deltas = []

        # Extract content segments between tool calls.
        content_segments = self._extract_content_between_tool_calls_list(
            parsed_text)

        # Build deltas: tool_call[i] -> content[i] (if exists).
        for i, tool_call in enumerate(tool_calls):
            tool_deltas = self._convert_tool_calls_to_deltas(
                [tool_call], self._tool_call_index + i)
            deltas.extend(tool_deltas)

            if i < len(content_segments) and content_segments[i]:
                deltas.append(DeltaMessage(content=content_segments[i]))

        return deltas

    def _update_state_after_tool_calls(self, tool_calls: list[ToolCall],
                                       end_pos: int) -> None:
        """Advance buffer position / indices after emitting tool calls."""
        self._processed_length = end_pos
        self._tool_call_index += len(tool_calls)

        # Update prev_tool_call_arr for finish_reason.
        self.prev_tool_call_arr.clear()
        for tool_call in tool_calls:
            if tool_call:
                self.prev_tool_call_arr.append({
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                })
--- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -52,6 +52,7 @@ "Step3VLConfig": "vllm.transformers_utils.configs.step3_vl", "Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl", "Step3TextConfig": "vllm.transformers_utils.configs.step3_vl", + "Step3p5Config": "vllm.transformers_utils.configs.step3p5", "Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr", "Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next", "Tarsier2Config": "vllm.transformers_utils.configs.tarsier2", @@ -95,6 +96,7 @@ "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "Step3p5Config", "Qwen3ASRConfig", "Qwen3NextConfig", "Tarsier2Config", diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py new file mode 100644 index 000000000000..5d34608a917c --- /dev/null +++ b/vllm/transformers_utils/configs/step3p5.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional, Union + +from transformers.configuration_utils import PretrainedConfig + + +class Step3p5Config(PretrainedConfig): + model_type = "step3p5" + + def __init__( + self, + hidden_size: int = 5120, + intermediate_size: int = 13312, + num_attention_heads: int = 40, + num_attention_groups: int = 8, + num_hidden_layers: int = 48, + max_seq_len: int = 4096, + vocab_size: int = 65536, + rms_norm_eps: float = 1e-5, + moe_every_n_layer: int = 2, + use_moe: bool = False, + moe_intermediate_size: int = 10240, + moe_num_experts: int = 16, + moe_top_k: int = 4, + max_pos_interp_ratio: float = 1, + moe_layer_offset: int = 0, + moe_dynamic_exp_p: float = 1.0, + rope_theta: Optional[Union[float, list[float]]] = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + head_dim: Optional[int] = None, + share_expert_dim: Optional[int] = None, + allgather_dtype: Optional[str] = None, + share_q_dim: 
Optional[int] = None, + norm_expert_weight: bool = True, + bos_token_id: Optional[Union[list[int], int]] = None, + eos_token_id: Optional[Union[list[int], int]] = None, + moe_router_activation: str = "softmax", + moe_router_scaling_factor: float = 1.0, + qk_nope_head_dim: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, + v_head_dim: Optional[int] = None, + q_lora_rank: Optional[int] = None, + kv_lora_rank: Optional[int] = None, + att_impl_type: str = "MFA", + use_head_wise_attn_gate: bool = False, + use_moe_router_bias: bool = False, + need_fp32_gate: bool = False, + layer_types: Optional[list[str]] = None, + use_rope_layers: Optional[list[bool]] = None, + yarn_only_types: Optional[list[str]] = None, + attention_other_setting: Optional[dict[str, Any]] = None, + num_nextn_predict_layers: int = 0, + swa_num_attention_heads: Optional[int] = None, + swiglu_limits: Optional[list[float]] = None, + swiglu_limits_shared: Optional[list[float]] = None, + zero_centered: bool = True, + max_position_embeddings: Optional[int] = None, + **kwargs, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_attention_groups = num_attention_groups + self.num_hidden_layers = num_hidden_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.rms_norm_eps = rms_norm_eps + self.use_moe = use_moe + self.moe_intermediate_size = moe_intermediate_size + self.moe_every_n_layer = moe_every_n_layer + self.moe_num_experts = moe_num_experts + self.num_experts_per_tok = moe_top_k + self.moe_top_k = moe_top_k + self.max_pos_interp_ratio = max_pos_interp_ratio + self.moe_layer_offset = moe_layer_offset + self.moe_dynamic_exp_p = moe_dynamic_exp_p + + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.head_dim = head_dim + if share_expert_dim is None: + self.share_expert_dim = self.moe_intermediate_size * self.moe_top_k + else: + self.share_expert_dim = 
share_expert_dim + self.share_q_dim = share_q_dim + self.norm_expert_weight = norm_expert_weight + + self.allgather_dtype = allgather_dtype + + self.max_position_embeddings = max_position_embeddings + self.moe_router_activation = moe_router_activation + self.moe_router_scaling_factor = moe_router_scaling_factor + self.use_moe_router_bias = use_moe_router_bias + self.need_fp32_gate = need_fp32_gate + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.att_impl_type = att_impl_type + self.use_head_wise_attn_gate = use_head_wise_attn_gate + self.layer_types = layer_types + self.use_rope_layers = use_rope_layers + self.yarn_only_types = yarn_only_types + self.attention_other_setting = attention_other_setting + self.num_nextn_predict_layers = num_nextn_predict_layers + self.swa_num_attention_heads = swa_num_attention_heads + self.swiglu_limits = swiglu_limits + self.swiglu_limits_shared = swiglu_limits_shared + self.zero_centered = zero_centered + + resolved_bos_token_id = 1 if bos_token_id is None else bos_token_id + resolved_eos_token_id = [2, 3] if eos_token_id is None else eos_token_id + self.bos_token_id = resolved_bos_token_id + self.eos_token_id = resolved_eos_token_id + + super().__init__( + bos_token_id=resolved_bos_token_id, + eos_token_id=resolved_eos_token_id, + **kwargs, + ) From 4c60890f0ec5cda82100419945ba60190a3f78aa Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 28 Jan 2026 02:14:11 +0800 Subject: [PATCH 02/34] fix: resolve diff from 014 --- vllm/model_executor/models/step3p5.py | 1 - vllm/reasoning/step3p5_reasoning_parser.py | 11 +++++----- .../tool_parsers/qwen3coder_tool_parser_rl.py | 22 ++++++++++++------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 1db8d13ca0fb..56dc92cb503b 100644 ---
a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -177,7 +177,6 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(hidden_states) - # dynamo 在sharedfusedmoe里面 合不了silu的两个torch op,不如直接调用cuda op intermediate_act = self.act_fn.forward_cuda(gate_up) output, _ = self.down_proj(intermediate_act) return output diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py index 50ac65aaa1b6..93aa7f5ee08d 100644 --- a/vllm/reasoning/step3p5_reasoning_parser.py +++ b/vllm/reasoning/step3p5_reasoning_parser.py @@ -3,16 +3,17 @@ from collections.abc import Sequence -from transformers import PreTrainedTokenizerBase +from vllm.tokenizers import TokenizerLike -from vllm.entrypoints.openai.protocol import ( +from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, - DeltaMessage, +) +from vllm.entrypoints.openai.responses.protocol import ( ResponsesRequest, ) +from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser - class Step3p5ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for Step3p5 model. 
@@ -31,7 +32,7 @@ def start_token(self) -> str: def end_token(self) -> str: return "" - def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) # Used to hold a trailing "\n" from reasoning content so we can decide diff --git a/vllm/tool_parsers/qwen3coder_tool_parser_rl.py b/vllm/tool_parsers/qwen3coder_tool_parser_rl.py index 78c5ee4f9585..ab8030a64027 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser_rl.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser_rl.py @@ -8,24 +8,30 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) class Qwen3CoderToolParserRL(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): + def __init__(self, tokenizer: TokenizerLike): super().__init__(tokenizer) self.current_tool_name_sent: bool = False From 28eecbab828b58258b39dd27c48d2d4c169f7b7b Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 28 Jan 2026 03:08:23 +0800 Subject: [PATCH 03/34] fix: fp8 activation should support swigluoai_step --- vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 3 files changed, 3 insertions(+), 
3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index ac37cff9329a..8d0816308f85 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -304,7 +304,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu"] + return activation in ["silu", "swigluoai-step"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 0a94757e335b..ce7411a45650 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -144,7 +144,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu"] + return activation in ["silu", "swigluoai-step"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2161af17ac8e..41abfa4d8a48 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1946,7 +1946,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai"] + return activation in ["silu", "gelu", "swigluoai", "swigluoai-step"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: From 17494c70279ea2b9bdcab41ea3650c6a10b91b96 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 28 Jan 2026 11:57:24 +0800 Subject: [PATCH 04/34] format: fix review comments for step3p5, remove 
groupwise-quant code --- vllm/model_executor/layers/activation.py | 20 +- .../layers/fused_moe/batched_deep_gemm_moe.py | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 4 +- vllm/model_executor/layers/fused_moe/utils.py | 10 +- vllm/model_executor/models/step3p5.py | 221 ++---------------- vllm/model_executor/models/step3p5_mtp.py | 82 ++----- 8 files changed, 63 insertions(+), 280 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 3420731542ed..b41278fb9fbc 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -22,20 +22,20 @@ logger = init_logger(__name__) -def swigluoai_step_and_mul_out( +def swiglustep_and_mul_out( out: torch.Tensor, x: torch.Tensor, limit: float, ) -> torch.Tensor: - """Out-variant of swigluoai-step activation. + """Out-variant of swiglustep activation. Writes into `out`: silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) """ # Prefer the fused custom op when available (CUDA); fallback to PyTorch ops # otherwise. - if x.is_cuda and hasattr(torch.ops._C, "swigluoai_step_and_mul"): - torch.ops._C.swigluoai_step_and_mul(out, x, limit) + if x.is_cuda and hasattr(torch.ops._C, "swiglustep_and_mul"): + torch.ops._C.swiglustep_and_mul(out, x, limit) else: gate, up = x.chunk(2, dim=-1) gate = F.silu(gate) @@ -327,9 +327,9 @@ def extra_repr(self) -> str: return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}" -# --8<-- [start:swigluoai_step_and_mul] -@CustomOp.register("swigluoai_step_and_mul") -class SwigluOAIStepAndMul(CustomOp): +# --8<-- [start:swiglustep_and_mul] +@CustomOp.register("swiglustep_and_mul") +class SwigluStepAndMul(CustomOp): """An activation function for SwiGLU with clamping. 
Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) @@ -340,12 +340,12 @@ class SwigluOAIStepAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ - # --8<-- [end:swigluoai_step_and_mul] + # --8<-- [end:swiglustep_and_mul] def __init__(self, limit: float): super().__init__() if limit is None: - raise ValueError("SwigluOAIStepAndMul requires limit to be set.") + raise ValueError("SwigluStepAndMul requires limit to be set.") self.limit = limit def forward_native(self, x: torch.Tensor) -> torch.Tensor: @@ -360,7 +360,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - swigluoai_step_and_mul_out(out, x, self.limit) + swiglustep_and_mul_out(out, x, self.limit) return out def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 8d0816308f85..8c081ef440f1 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -304,7 +304,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "swigluoai-step"] + return activation in ["silu", "swiglustep"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index ce7411a45650..775ba132992e 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -144,7 +144,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "swigluoai-step"] + return activation in ["silu", "swiglustep"] 
@staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 41abfa4d8a48..1872ab817302 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1946,7 +1946,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai", "swigluoai-step"] + return activation in ["silu", "gelu", "swigluoai", "swiglustep"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 19c5d8dc92f3..532a700986d6 100755 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -521,9 +521,9 @@ def __init__( self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation self.activation_limit = activation_limit - if self.activation == "swigluoai-step" and self.activation_limit is None: + if self.activation == "swiglustep" and self.activation_limit is None: raise ValueError( - "activation='swigluoai-step' requires activation_limit to be set." + "activation='swiglustep' requires activation_limit to be set." ) self.router = create_fused_moe_router( diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 9cdbf33e0564..f0d98d807a28 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -336,7 +336,7 @@ def apply_moe_activation( """ Apply MoE activation function. 
- For *_and_mul activations (silu, gelu, swigluoai, swigluoai-step): + For *_and_mul activations (silu, gelu, swigluoai, swiglustep): - Expects output.size(-1) * 2 == input.size(-1) For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul): @@ -359,14 +359,14 @@ def apply_moe_activation( torch.ops._C.gelu_and_mul(output, input) elif activation == "swigluoai": torch.ops._C.swigluoai_and_mul(output, input) - elif activation == "swigluoai-step": + elif activation == "swiglustep": if activation_limit is None: raise ValueError( - "activation='swigluoai-step' requires activation_limit to be set." + "activation='swiglustep' requires activation_limit to be set." ) - from vllm.model_executor.layers.activation import swigluoai_step_and_mul_out + from vllm.model_executor.layers.activation import swiglustep_and_mul_out - swigluoai_step_and_mul_out(output, input, activation_limit) + swiglustep_and_mul_out(output, input, activation_limit) # Activations without gated multiplication elif activation == SILU_NO_MUL: output.copy_(F.silu(input)) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 56dc92cb503b..31a95d360a8b 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" -import math from collections.abc import Iterable from typing import Any, Optional, Union @@ -19,9 +18,8 @@ get_tensor_model_parallel_world_size, get_tp_group) from vllm.logger import init_logger -from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIStepAndMul +from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -58,78 +56,7 @@ def sigmoid_routing_function(hidden_states: torch.Tensor, if renormalize: expert_topk_weight = expert_topk_weight / torch.sum( expert_topk_weight, dim=-1, keepdim=True) - return expert_topk_weight, indices.to(torch.int32) - -def Step3p5RMSNorm( - hidden_size: int, - eps: float = 1e-6, - zero_centered: bool = True, -): - if zero_centered: - return GemmaRMSNorm(hidden_size, eps) - else: - return RMSNorm(hidden_size, eps) - -def pad_param( - weight: torch.Tensor, - name: str, - param: torch.nn.Parameter, - quant_config: Optional[QuantizationConfig] = None, -) -> torch.Tensor: - """Pad 2D weight for groupwise quantization TP sharding. - - Decide whether to pad based on `param.quant_method`: - - None / UnquantizedLinearMethod / UnquantizedFusedMoEMethod => no padding - - otherwise => treat as quantized and pad if groupwise_quant config is found - """ - if weight.dim() != 2: - return weight - - quant_method = getattr(param, "quant_method", None) - if not quant_config or quant_config.get_name( - ) != "groupwise_quant" or not quant_method: - return weight - - world_size = get_tensor_model_parallel_world_size() - group_size = quant_config.group_size - - if ("down_proj.scales" in name) or ("w2_weight_scale" in name): - group_size = 1 - - ic, oc = weight.shape - if ("down" in name) or ("w2" in name): - ic_pad = int( - math.ceil(ic / group_size / world_size) * world_size * - group_size) - ic - out = torch.nn.functional.pad(weight, (0, 0, 0, ic_pad), "constant", 0) - else: - oc_pad = int( - math.ceil(oc / group_size / world_size) * world_size * - group_size) - oc - out = torch.nn.functional.pad(weight, (0, oc_pad, 0, 0), "constant", 0) - - logger.debug( - f"padding {name} ,quant_config={quant_config},original weight.shape: {weight.shape}, padded weight.shape: {out.shape}" - ) - return out - - -def _pad_size_for_groupwise_quant( - size: int, - quant_config: 
Optional[QuantizationConfig], -) -> int: - """Pad `size` to be a multiple of (group_size * tensor_parallel_world_size). - - This is needed for groupwise quantization TP sharding. - """ - if quant_config is None or quant_config.get_name() != "groupwise_quant": - return size - world_size = get_tensor_model_parallel_world_size() - - group_size = quant_config.group_size - multiple = world_size * group_size - return math.ceil(size / multiple) * multiple - + return expert_topk_weight, indices.to(torch.int32) class Step3p5MLP(nn.Module): @@ -144,10 +71,6 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - - intermediate_size = _pad_size_for_groupwise_quant( - intermediate_size, quant_config) - self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, @@ -173,7 +96,7 @@ def __init__( layer_idx] is not None and config.swiglu_limits_shared[ layer_idx] != 0: self.limit = config.swiglu_limits_shared[layer_idx] - self.act_fn = SwigluOAIStepAndMul(limit=self.limit) + self.act_fn = SwigluStepAndMul(limit=self.limit) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(hidden_states) @@ -198,17 +121,14 @@ def __init__( rope_scaling: Optional[tuple] = None, prefix: str = "", attn_type: str = AttentionType.DECODER, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, # Step3p5 specific args sliding_window: Optional[int] = None, - enable_sink: bool = False, use_head_wise_attn_gate: bool = False, layer_types: list = None, use_rope_layers: list = None, yarn_only_types: list = None, swa_num_attention_heads: Optional[int] = None, partial_rotary_factor: float = 1.0, - zero_centered: bool = True, ): super().__init__() self.hidden_size = hidden_size @@ -229,14 +149,7 @@ def __init__( if swa_num_attention_heads is not None: num_heads = swa_num_attention_heads self.total_num_heads = swa_num_attention_heads - if enable_sink: - self.sinks = torch.nn.Parameter(torch.empty( - 
self.total_num_heads // tp_size, dtype=torch.bfloat16), - requires_grad=False) - else: - self.sinks = None else: - self.sinks = None sliding_window = None if isinstance(rope_theta, list): @@ -261,7 +174,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -279,34 +191,16 @@ def __init__( prefix=f"{prefix}.o_proj", ) - # vLLM >= 0.7 uses `rope_parameters` for all RoPE scaling variants. - # `rope_scaling` (HF-style) is mapped into `rope_parameters` here to - # preserve the behavior of older Step3p5 implementations. rope_parameters: dict[str, Any] = { "rope_type": "default", "partial_rotary_factor": partial_rotary_factor, } if rope_scaling is not None: - if isinstance(rope_scaling, dict): - rope_parameters.update(rope_scaling) - elif isinstance(rope_scaling, (tuple, list)) and rope_scaling and isinstance( - rope_scaling[0], dict): - # Per-layer rope scaling configs. - if self.layer_idx < len(rope_scaling): - rope_parameters.update(rope_scaling[self.layer_idx]) - elif isinstance(rope_scaling, - (tuple, list)) and len(rope_scaling) == 2 and isinstance( - rope_scaling[0], str): - # Legacy tuple format: (type, factor) - rope_parameters.update({ - "rope_type": rope_scaling[0], - "factor": rope_scaling[1], - }) - if "type" in rope_parameters: - rope_parameters.setdefault("rope_type", rope_parameters["type"]) - rope_parameters.pop("type", None) - # Always take the per-layer resolved rope theta, instead of trusting - # any potentially list-valued rope_theta coming from rope_scaling. + if not isinstance(rope_scaling, dict): + raise ValueError( + "rope_scaling must be a dict for Step3p5Attention." 
+ ) + rope_parameters.update(rope_scaling) rope_parameters["rope_theta"] = self.rope_theta rope_parameters["partial_rotary_factor"] = partial_rotary_factor @@ -314,16 +208,10 @@ def __init__( head_size=self.head_dim, max_position=max_position, rope_parameters=rope_parameters, - dual_chunk_attention_config=dual_chunk_attention_config, ) - self.q_norm = Step3p5RMSNorm(self.head_dim, - eps=rms_norm_eps, - zero_centered=zero_centered) - self.k_norm = Step3p5RMSNorm(self.head_dim, - eps=rms_norm_eps, - zero_centered=zero_centered) - self.zero_centered = zero_centered + self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps) self.use_head_wise_attn_gate = use_head_wise_attn_gate if use_head_wise_attn_gate: self.g_proj = ColumnParallelLinear( @@ -337,7 +225,6 @@ def __init__( if use_rope_layers: self.use_rope = use_rope_layers[self.layer_idx] - # TODO: Add sink attention self.attn = Attention( self.num_heads, self.head_dim, @@ -348,10 +235,6 @@ def __init__( prefix=f"{prefix}.attn", per_layer_sliding_window=sliding_window, attn_type=attn_type, - **{ - "layer_idx": extract_layer_index(prefix), - "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}, ) self.max_position_embeddings = max_position @@ -359,7 +242,6 @@ def __init__( self.rotary_dim = self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2 def qk_norm_rope(self, q, k, positions): - # Add qk-norm q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head.contiguous()) @@ -443,20 +325,16 @@ def __init__(self, if config.swiglu_limits and config.swiglu_limits[ layer_idx] is not None and config.swiglu_limits[layer_idx] != 0: swigluoai_step_limit = config.swiglu_limits[layer_idx] - activation = "swigluoai-step" + activation = "swiglustep" logger.info( - f"step3p5 layer_idx: {layer_idx}, activation limit: {config.swiglu_limits[layer_idx]}, will 
use swigluoai-step" + f"step3p5 layer_idx: {layer_idx}, activation limit: {config.swiglu_limits[layer_idx]}, will use swiglustep" ) - moe_intermediate_size = _pad_size_for_groupwise_quant( - config.moe_intermediate_size, - quant_config, - ) self.experts = SharedFusedMoE( shared_experts=shared_experts, num_experts=config.moe_num_experts, top_k=config.moe_top_k, hidden_size=config.hidden_size, - intermediate_size=moe_intermediate_size, + intermediate_size=config.moe_intermediate_size, reduce_results=reduce_results, renormalize=config.norm_expert_weight, quant_config=quant_config, @@ -497,7 +375,6 @@ def forward( router_logits = hidden_states.to( torch.float32) @ self.gate.weight.to(torch.float32).t() else: - # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) shared_out, final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) @@ -512,7 +389,6 @@ def __init__(self, parallel_config: ParallelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - use_fused_moe: bool = False, prefix: str = "") -> None: super().__init__() config = config.hf_config @@ -553,7 +429,6 @@ def __init__(self, quant_config=quant_config, rope_scaling=rope_scaling, sliding_window=getattr(config, 'sliding_window', None), - enable_sink=getattr(config, "sink", False), use_head_wise_attn_gate=getattr(config, "use_head_wise_attn_gate", False), @@ -562,7 +437,6 @@ def __init__(self, yarn_only_types=getattr(config, "yarn_only_types", []), partial_rotary_factor=partial_rotary_factors[layer_idx] if partial_rotary_factors else 1.0, - zero_centered=getattr(config, "zero_centered", False), prefix=f"{prefix}.self_attn", ) else: @@ -584,19 +458,16 @@ def __init__(self, int(i) for i in moe_layers_enum.strip().split(',') ] else: - # Default to 1dense. 
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] if layer_idx in moe_layers_idx: reduce_results = True if self.use_fused_all_reduce or self.tp_group.world_size == 1 and get_ep_group( ).world_size == 1: reduce_results = False - moe_intermediate_size = _pad_size_for_groupwise_quant( - config.share_expert_dim, quant_config) self.share_expert = Step3p5MLP( config=config, hidden_size=self.hidden_size, - intermediate_size=moe_intermediate_size, + intermediate_size=config.share_expert_dim, hidden_act="silu", reduce_results=reduce_results, quant_config=quant_config, @@ -616,15 +487,8 @@ def __init__(self, quant_config=quant_config, reduce_results=True, prefix=f"{prefix}.mlp") - self.use_fused_moe = use_fused_moe - self.input_layernorm = Step3p5RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - zero_centered=config.zero_centered) - self.post_attention_layernorm = Step3p5RMSNorm( - config.hidden_size, - eps=config.rms_norm_eps, - zero_centered=config.zero_centered) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.prefix = prefix def add_and_maybe_inplace_all_reduce(self, in1: torch.Tensor, @@ -656,22 +520,18 @@ def forward(self, positions: torch.Tensor, hidden_states = ffn_output + residual return hidden_states - -# Note: max-num-batched-tokens 开到64k 编译会不通过,小于64k没啥问题 @support_torch_compile class Step3p5Model(nn.Module): def __init__(self, vllm_config: VllmConfig, - prefix: str = "", - use_fused_moe: bool = False) -> None: + prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size self.config = config - self.use_fused_moe = use_fused_moe self.moe_num_experts = config.moe_num_experts @@ -691,14 +551,11 @@ def __init__(self, parallel_config=vllm_config.parallel_config, 
cache_config=cache_config, quant_config=quant_config, - use_fused_moe=self.use_fused_moe, prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: - self.norm = Step3p5RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - zero_centered=config.zero_centered) + self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) else: self.norm = PPMissingLayer() @@ -751,8 +608,7 @@ def __init__( self.vllm_config = vllm_config self.model = Step3p5Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - use_fused_moe=True) + prefix=maybe_prefix(prefix, "model")) self.moe_layers: list[FusedMoEBlock] = [] for layer in self.model.layers: @@ -818,7 +674,6 @@ def set_eplb_state( logical_replica_count: torch.Tensor, ) -> None: for layer_idx, layer in enumerate(self.moe_layers): - # layer_idx = layer.layer_idx experts = layer.experts assert isinstance(experts, FusedMoE) # Register the expert weights. @@ -851,7 +706,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): vllm_config = self.vllm_config config = vllm_config.model_config.hf_config assert config.num_attention_groups > 1, "Only support GQA" - #GQA qkv_params_mapping = [] stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -865,30 +719,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) loaded_params = set() - if self.model.use_fused_moe: - is_groupwise_quant = self.vllm_config.quant_config is not None and self.vllm_config.quant_config.get_name( - ) == "groupwise_quant" - if is_groupwise_quant: - expert_params_mapping = [ - (".moe.experts.w13_weight", ".moe.gate_proj.qweight", - "w1"), - (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"), - (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales", - "w1"), - (".moe.experts.w13_weight_scale", ".moe.up_proj.scales", - "w3"), - (".moe.experts.w2_weight_scale", 
".moe.down_proj.scales", - "w2"), - ] - else: - expert_params_mapping = [ - (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), - (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") - ] - else: - expert_params_mapping = [] + expert_params_mapping = [ + (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + ] disable_moe_stacked_params = [data[1] for data in expert_params_mapping] @@ -909,8 +744,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): continue param = params_dict[name] weight_loader = param.weight_loader - loaded_weight = pad_param(loaded_weight, name, param, - self.vllm_config.quant_config) weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break @@ -933,9 +766,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert loaded_weight.shape[0] == moe_expert_num for expert_id in range(moe_expert_num): loaded_weight_expert = loaded_weight[expert_id] - loaded_weight_expert = pad_param( - loaded_weight_expert, name, param, - self.vllm_config.quant_config) weight_loader(param, loaded_weight_expert, name, @@ -967,9 +797,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): logger.warning_once("ignore expert_bias") continue param = params_dict[name] - loaded_weight = pad_param( - loaded_weight, name, param, - self.vllm_config.quant_config) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py index eb3da14f2b7d..a7747d09e9d5 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -9,15 +9,14 @@ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from 
vllm.logger import init_logger -# from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors - -from .step3p5 import Step3p5RMSNorm, Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name from .utils import maybe_prefix logger = init_logger(__name__) @@ -31,9 +30,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.norm = Step3p5RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - zero_centered=config.zero_centered) + self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -55,12 +52,8 @@ def __init__( ) -> None: super().__init__() - self.enorm = Step3p5RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - zero_centered=config.zero_centered) - self.hnorm = Step3p5RMSNorm(config.hidden_size, - eps=config.rms_norm_eps, - zero_centered=config.zero_centered) + self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) @@ -69,7 +62,6 @@ def __init__( parallel_config=parallel_config, cache_config=cache_config, quant_config=quant_config, - use_fused_moe=True, prefix=f"{prefix}.mtp_block") def forward( @@ -121,7 +113,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): }) self.logits_processor = 
LogitsProcessor(config.vocab_size) - self.use_fused_moe = True def forward( self, @@ -196,55 +187,20 @@ def load_weights(self, weights: Iterable[tuple[str, vllm_config = self.vllm_config config = vllm_config.model_config.hf_config - if config.num_attention_groups > 1: - #GQA - # qkv_params_mapping=[] - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - elif config.att_impl_type == "MLA": - # qkv_params_mapping = [] - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - else: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - - if self.model.use_fused_moe: - if self.vllm_config.quant_config is not None and self.vllm_config.quant_config.get_name( - ) == "groupwise_quant": - expert_params_mapping = [ - (".moe.experts.w13_weight", ".moe.gate_proj.qweight", - "w1"), - (".moe.experts.w13_weight", ".moe.up_proj.qweight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.qweight", "w2"), - (".moe.experts.w13_weight_scale", ".moe.gate_proj.scales", - "w1"), - (".moe.experts.w13_weight_scale", ".moe.up_proj.scales", - "w3"), - (".moe.experts.w2_weight_scale", ".moe.down_proj.scales", - "w2"), - ] - else: - expert_params_mapping = [ - (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), - (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") - ] - else: - expert_params_mapping = [] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = 
[ + (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() From 9b0cbdbef4b1a13d1edca4182798185401295fe4 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 28 Jan 2026 20:36:23 +0800 Subject: [PATCH 05/34] feat: add step3p5_tool_parser --- vllm/tool_parsers/__init__.py | 8 +- .../tool_parsers/qwen3coder_tool_parser_rl.py | 812 ---------- vllm/tool_parsers/step3p5_tool_parser.py | 1401 +++++++++++++++++ 3 files changed, 1405 insertions(+), 816 deletions(-) delete mode 100644 vllm/tool_parsers/qwen3coder_tool_parser_rl.py create mode 100644 vllm/tool_parsers/step3p5_tool_parser.py diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py index fee125c398ad..c1a39f2afa02 100644 --- a/vllm/tool_parsers/__init__.py +++ b/vllm/tool_parsers/__init__.py @@ -122,10 +122,6 @@ "qwen3coder_tool_parser", "Qwen3CoderToolParser", ), - "qwen3_coder_rl": ( - "qwen3coder_tool_parser_rl", - "Qwen3CoderToolParserRL", - ), "qwen3_xml": ( "qwen3xml_tool_parser", "Qwen3XMLToolParser", @@ -138,6 +134,10 @@ "step3_tool_parser", "Step3ToolParser", ), + "step3p5": ( + "step3p5_tool_parser", + "Step3p5ToolParser", + ), "xlam": ( "xlam_tool_parser", "xLAMToolParser", diff --git a/vllm/tool_parsers/qwen3coder_tool_parser_rl.py b/vllm/tool_parsers/qwen3coder_tool_parser_rl.py deleted file mode 100644 index ab8030a64027..000000000000 --- a/vllm/tool_parsers/qwen3coder_tool_parser_rl.py +++ /dev/null @@ -1,812 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import ast -import json -import uuid -from collections.abc import Sequence -from typing import Any, Optional, Union - -import regex as re - -from vllm.entrypoints.openai.chat_completion.protocol import ( - ChatCompletionRequest, - ChatCompletionToolsParam, -) 
-from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, - ToolCall, -) -from vllm.tool_parsers.abstract_tool_parser import ( - ToolParser, -) -from vllm.logger import init_logger -from vllm.tokenizers import TokenizerLike - -logger = init_logger(__name__) - - -class Qwen3CoderToolParserRL(ToolParser): - - def __init__(self, tokenizer: TokenizerLike): - super().__init__(tokenizer) - - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: list[dict] = [] - # Override base class type - we use string IDs for tool calls - self.current_tool_id: Optional[str] = None # type: ignore - self.streamed_args_for_tool: list[str] = [] - - # Sentinel tokens for streaming mode - self.tool_call_start_token: str = "" - self.tool_call_end_token: str = "" - self.tool_call_prefix: str = "(.*?)", re.DOTALL) - self.tool_call_function_regex = re.compile( - r"", re.DOTALL) - self.tool_call_parameter_regex = re.compile( - r"", re.DOTALL) - - if not self.model_tokenizer: - raise ValueError( - "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) - self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): - raise RuntimeError( - "Qwen3 XML Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") - - # Get EOS token ID for EOS detection - self.eos_token_id = getattr(self.model_tokenizer, 'eos_token_id', None) - - logger.info("vLLM Successfully import tool parser %s !", - self.__class__.__name__) - - def _generate_tool_call_id(self) -> str: - """Generate a unique tool call ID.""" - return f"call_{uuid.uuid4().hex[:24]}" - - def _reset_streaming_state(self): - """Reset all streaming state for a new request.""" - self._processed_length: int 
= 0 # Position of last processed character - self._tool_call_index: int = 0 # Number of tool calls processed so far - self.streaming_request = None # Current request being processed - - def _get_arguments_config( - self, func_name: str, - tools: Optional[list[ChatCompletionToolsParam]]) -> dict: - """Extract argument configuration for a function.""" - if tools is None: - return {} - for config in tools: - if not hasattr(config, "type") or not (hasattr( - config, "function") and hasattr(config.function, "name")): - continue - if config.type == "function" and config.function.name == func_name: - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) - return {} - - def _convert_param_value(self, param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: - """Convert parameter value based on its type in the schema.""" - # Handle null value for any type - if param_value.lower() == "null": - return None - - if param_name not in param_config: - if param_config != {}: - logger.warning( - "Parsed parameter '%s' is not defined in the tool " - "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) - return param_value - - if isinstance(param_config[param_name], - dict) and "type" in param_config[param_name]: - param_type = str(param_config[param_name]["type"]).strip().lower() - else: - param_type = "string" - if param_type in ["string", "str", "text", "varchar", "char", "enum"]: - return param_value - elif param_type.startswith("int") or param_type.startswith( - "uint") or param_type.startswith( - "long") or param_type.startswith( - "short") or param_type.startswith("unsigned"): - try: - return int(param_value) - except (ValueError, TypeError) 
as e: - raise ValueError( - f"Parsed value '{param_value}' of parameter '{param_name}' " - f"is not an integer in tool '{func_name}'.") from e - elif param_type.startswith("num") or param_type.startswith("float"): - try: - float_param_value = float(param_value) - return float_param_value if float_param_value - int( - float_param_value) != 0 else int(float_param_value) - except (ValueError, TypeError) as e: - raise ValueError( - f"Parsed value '{param_value}' of parameter '{param_name}' " - f"is not a float in tool '{func_name}'.") from e - elif param_type in ["boolean", "bool", "binary"]: - param_value = param_value.lower() - if param_value not in ["true", "false"]: - raise ValueError( - f"Parsed value '{param_value}' of parameter '{param_name}' " - f"is not a boolean (`true` or `false`) in tool '{func_name}'." - ) - return param_value == "true" - else: - if param_type in ["object", "array", "arr" - ] or param_type.startswith( - "dict") or param_type.startswith("list"): - try: - param_value = json.loads(param_value) - return param_value - except (json.JSONDecodeError, TypeError, ValueError) as e: - raise ValueError( - f"Parsed value '{param_value}' of parameter '{param_name}' " - f"cannot be parsed with json.loads in tool '{func_name}'." 
- ) from e - try: - param_value = ast.literal_eval(param_value) # safer - except (ValueError, SyntaxError, TypeError) as e: - raise ValueError( - f"Parsed value '{param_value}' of parameter '{param_name}' " - f"cannot be converted via Python `ast.literal_eval()` in tool " - f"'{func_name}'.") from e - return param_value - - def _parse_xml_function_call( - self, function_call_str: str, - tools: Optional[list[ChatCompletionToolsParam]] - ) -> Optional[ToolCall]: - - # Extract function name - end_index = function_call_str.index(">") - function_name = function_call_str[:end_index] - param_config = self._get_arguments_config(function_name, tools) - parameters = function_call_str[end_index + 1:] - param_dict = {} - for match_text in self.tool_call_parameter_regex.findall(parameters): - idx = match_text.index(">") - param_name = match_text[:idx] - param_value = str(match_text[idx + 1:]) - # Remove prefix and trailing \n - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - try: - param_dict[param_name] = self._convert_param_value( - param_value, param_name, param_config, function_name) - except Exception: - return None - return ToolCall( - type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(param_dict, - ensure_ascii=False)), - ) - - def _get_function_calls(self, model_output: str) -> list[str]: - # Find all tool calls - raw_tool_calls = self.tool_call_complete_regex.findall(model_output) - - # if no closed tool_call tags found, return empty list - if len(raw_tool_calls) == 0: - return [] - - raw_function_calls = [] - for tool_call in raw_tool_calls: - function_matches = self.tool_call_function_regex.findall(tool_call) - raw_function_calls.extend(function_matches) - - return raw_function_calls - - def _check_format(self, model_output: str) -> bool: - """Check if model output contains properly formatted tool call. - - Requirements: - 1. 
Must have closed tool_call tags (...) - 2. Must have closed function tags () - 3. If parameter tags exist, they must be closed and correct - - Returns True if the format is valid, False otherwise. - """ - # Check 1: Must have closed tool_call tags - tool_call_matches = self.tool_call_complete_regex.findall(model_output) - if len(tool_call_matches) == 0: - return False - - # Check 2: Must have closed function tags within tool_call - has_valid_function = False - for tool_call_content in tool_call_matches: - function_matches = self.tool_call_function_regex.findall( - tool_call_content) - if len(function_matches) > 0: - has_valid_function = True - # Check if there's an unclosed function tag - if self.tool_call_prefix in tool_call_content and self.function_end_token not in tool_call_content: - return False - - if not has_valid_function: - return False - - # Check 3: If parameter tags exist, they must be closed and correct - for tool_call_content in tool_call_matches: - # Count opening and closing parameter tags - param_open_count = tool_call_content.count(self.parameter_prefix) - param_close_count = tool_call_content.count( - self.parameter_end_token) - - # If there are parameter tags, they must be balanced - if param_open_count > 0: - if param_open_count != param_close_count: - return False - # Check if all parameter tags are properly closed using regex - param_matches = self.tool_call_parameter_regex.findall( - tool_call_content) - if len(param_matches) != param_open_count: - return False - - return True - - def extract_tool_calls( - self, - model_output: str, - request: ChatCompletionRequest, - ) -> ExtractedToolCallInformation: - # Quick check to avoid unnecessary processing - if not self._check_format(model_output): - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - - try: - function_calls = self._get_function_calls(model_output) - if len(function_calls) == 0: - return 
ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - - tool_calls = [ - self._parse_xml_function_call(function_call_str, request.tools) - for function_call_str in function_calls - ] - - # Populate prev_tool_call_arr for serving layer to set finish_reason - self.prev_tool_call_arr.clear() # Clear previous calls - for tool_call in tool_calls: - if tool_call: - self.prev_tool_call_arr.append({ - "name": - tool_call.function.name, - "arguments": - tool_call.function.arguments, - }) - - # Extract content before tool calls - content_index = model_output.find(self.tool_call_start_token) - content = model_output[:content_index] # .rstrip() - - return ExtractedToolCallInformation( - tools_called=(len(tool_calls) > 0), - tool_calls=tool_calls, - content=content if content else None, - ) - - except Exception: - logger.warning("Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - - def _find_first_complete_tool_call_end(self, - text: str, - start_pos: int = 0) -> int: - """Find the end position of the first complete tool call. - - Args: - text: Text to search in - start_pos: Position to start searching from - - Returns: - Position after the first tag, or -1 if incomplete - - Example: - "......" returns position after - """ - # Find tool call start - start_idx = text.find(self.tool_call_start_token, start_pos) - if start_idx == -1: - return -1 - - # Find matching end token - end_idx = text.find(self.tool_call_end_token, - start_idx + len(self.tool_call_start_token)) - if end_idx == -1: - return -1 # Incomplete tool call - - # Return position after end token - return end_idx + len(self.tool_call_end_token) - - def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int: - """Find the start position of next tool call. 
- - Args: - text: Text to search in - start_pos: Position to start searching from - - Returns: - Position of token, or -1 if not found - """ - return text.find(self.tool_call_start_token, start_pos) - - def _extract_content_between_tool_calls_list(self, text: str) -> list[str]: - """Extract content segments after each tool call. - - For n tool calls, returns n segments where segment[i] is the content - after tool_call[i] (before tool_call[i+1] or at the end). - - Empty or whitespace-only segments are represented as empty string "". - - Args: - text: Text containing tool calls - - Returns: - List of content segments (one per tool call) - """ - content_segments = [] - pos = 0 - - while True: - # Find end of current tool call - end_pos = text.find(self.tool_call_end_token, pos) - if end_pos == -1: - break - - # Move past the end token - end_pos += len(self.tool_call_end_token) - - # Find start of next tool call - next_start = self._find_tool_call_start(text, end_pos) - - # Extract content between current end and next start (or text end) - content = text[end_pos:next_start] if next_start != -1 else text[ - end_pos:] - - # Store content (empty string if whitespace-only) - content_segments.append(content if content.strip() else "") - - if next_start == -1: - break - pos = next_start - - return content_segments - - def _convert_tool_calls_to_deltas( - self, - tool_calls: list[ToolCall], - starting_index: int = 0) -> list[DeltaMessage]: - """Convert complete ToolCall list to delta message sequence. - - Format: header (function name) -> { -> param1 -> param2 -> ... 
-> } - - Args: - tool_calls: List of tool calls to convert - starting_index: Starting index for tool calls (default 0) - """ - deltas = [] - for i, tool_call in enumerate(tool_calls): - index = starting_index + i - tool_id = self._generate_tool_call_id() - - # Header delta: function name - deltas.append( - DeltaMessage(tool_calls=[ - DeltaToolCall( - index=index, - id=tool_id, - function=DeltaFunctionCall( - name=tool_call.function.name, arguments=""), - type="function", - ) - ])) - - # Opening brace - deltas.append( - DeltaMessage(tool_calls=[ - DeltaToolCall(index=index, - function=DeltaFunctionCall(arguments="{")) - ])) - - # Parse arguments JSON to extract parameters - try: - args_dict = json.loads(tool_call.function.arguments) - param_names = list(args_dict.keys()) - - for param_idx, param_name in enumerate(param_names): - param_value = args_dict[param_name] - serialized_value = json.dumps(param_value, - ensure_ascii=False) - - if param_idx == 0: - json_fragment = f'"{param_name}": {serialized_value}' - else: - json_fragment = f', "{param_name}": {serialized_value}' - - deltas.append( - DeltaMessage(tool_calls=[ - DeltaToolCall(index=index, - function=DeltaFunctionCall( - arguments=json_fragment)) - ])) - except (json.JSONDecodeError, KeyError): - # If parsing fails, just send the arguments as-is - if tool_call.function.arguments: - deltas.append( - DeltaMessage(tool_calls=[ - DeltaToolCall( - index=index, - function=DeltaFunctionCall( - arguments=tool_call.function.arguments)) - ])) - - # Closing brace - deltas.append( - DeltaMessage(tool_calls=[ - DeltaToolCall(index=index, - function=DeltaFunctionCall(arguments="}")) - ])) - - return deltas - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: Sequence[int], - current_token_ids: Sequence[int], - delta_token_ids: Sequence[int], - request: ChatCompletionRequest, - ) -> Union[DeltaMessage, list[DeltaMessage], None]: - """Extract tool 
calls from streaming text using complete parsing. - - Strategy: - 1. Accumulate text in buffer and track processed position - 2. In each iteration, try to extract content or complete tool calls - 3. Parse complete tool calls using non-streaming method - 4. Convert parsed results to delta sequence - 5. Handle EOS token to flush incomplete tool calls as content - """ - # Initialize state for new request - if not previous_text: - self._reset_streaming_state() - self.streaming_request = request - - # Check for EOS token - has_eos = (self.eos_token_id is not None and delta_token_ids - and self.eos_token_id in delta_token_ids) - - # NOTE: The above simple check may incorrectly detect EOS when model output - # contains <|im_end|> tokens (e.g., in multi-turn conversations). - # If needed, use the more sophisticated check below: - # - # has_eos = False - # if self.eos_token_id is not None and delta_token_ids and self.eos_token_id in delta_token_ids: - # if not delta_text: - # # Mode 1: Empty delta with EOS - definitely stream terminator - # has_eos = True - # elif delta_text: - # # Mode 2: Check if EOS is extra (not part of delta_text encoding) - # # Encode delta_text to see how many tokens it should produce - # encoded_delta = self.model_tokenizer.encode(delta_text, add_special_tokens=False) - # # If delta_token_ids has MORE tokens than encoded_delta, - # # the extra token is the EOS terminator - # if len(delta_token_ids) > len(encoded_delta): - # has_eos = True - - # If no delta text, check if we need to return empty delta for finish_reason - if not delta_text and not has_eos: - # Check if this is an EOS token after all tool calls are complete - # Similar to qwen3coder_tool_parser.py logic - if (delta_token_ids - and self.tool_call_end_token_id not in delta_token_ids): - # Count complete tool calls - complete_calls = len( - self.tool_call_complete_regex.findall(current_text)) - - # If we have completed tool calls and populated prev_tool_call_arr - if complete_calls > 0 
and len(self.prev_tool_call_arr) > 0: - # Check if all tool calls are closed - open_calls = current_text.count( - self.tool_call_start_token) - current_text.count( - self.tool_call_end_token) - if open_calls == 0: - # Return empty delta for finish_reason processing - return DeltaMessage(content="") - return None - - # Process all available content - accumulated_deltas: list[DeltaMessage] = [] - - while self._has_unprocessed_content(current_text): - # Try to process next chunk (content or tool call) - delta = self._process_next_chunk(current_text) - - if delta is None: - # Cannot proceed further, need more tokens - break - - # Accumulate deltas - if isinstance(delta, list): - accumulated_deltas.extend(delta) - else: - accumulated_deltas.append(delta) - - # Handle EOS: flush any remaining incomplete tool calls as content - if has_eos: - remaining_delta = self._flush_remaining_content(current_text) - if remaining_delta: - accumulated_deltas.append(remaining_delta) - # If no remaining content but we have tool calls, return empty delta - elif len(self.prev_tool_call_arr) > 0: - # Check if all tool calls are closed - open_calls = current_text.count( - self.tool_call_start_token) - current_text.count( - self.tool_call_end_token) - if open_calls == 0: - accumulated_deltas.append(DeltaMessage(content="")) - - # Return results - return self._format_delta_result(accumulated_deltas) - - def _has_unprocessed_content(self, current_text: str) -> bool: - """Check if there's unprocessed content in the buffer.""" - return self._processed_length < len(current_text) - - def _process_next_chunk( - self, current_text: str - ) -> Union[DeltaMessage, list[DeltaMessage], None]: - """Process next chunk: either regular content or a complete tool call. 
- - Args: - current_text: Current accumulated text - - Returns: - - DeltaMessage or list of DeltaMessage if processed successfully - - None if cannot proceed (need more tokens) - """ - # Find next tool call start - tool_start_idx = self._find_tool_call_start(current_text, - self._processed_length) - - # Case 1: No tool call found - return remaining content - if tool_start_idx == -1: - return self._process_content(current_text, self._processed_length, - len(current_text)) - - # Case 2: Content before tool call - if tool_start_idx > self._processed_length: - return self._process_content(current_text, self._processed_length, - tool_start_idx) - - # Case 3: Tool call at current position - # Find end of the first complete tool call - tool_end_idx = self._find_first_complete_tool_call_end( - current_text, tool_start_idx) - - if tool_end_idx == -1: - # Tool call incomplete, wait for more tokens - return None - - # Process complete tool call - return self._process_complete_tool_calls(current_text, tool_start_idx, - tool_end_idx) - - def _process_content(self, current_text: str, start_pos: int, - end_pos: int) -> Union[DeltaMessage, None]: - """Process regular content (non-tool-call text). 
- - Args: - current_text: Current accumulated text - start_pos: Start position in buffer - end_pos: End position in buffer - - Returns: - DeltaMessage with content if non-empty - """ - if start_pos >= end_pos: - return None - - content = current_text[start_pos:end_pos] - - # Check if we're between tool calls - skip whitespace - # Similar to qwen3coder_tool_parser.py logic - if start_pos > 0: - # Check if text before start_pos (after stripping trailing whitespace) ends with - text_before = current_text[:start_pos] - if (text_before.rstrip().endswith(self.tool_call_end_token) - and content.strip() == ""): - # We just ended a tool call, skip whitespace between tool calls - self._processed_length = end_pos - return None - - # Return content if non-empty - if content: - self._processed_length = end_pos - return DeltaMessage(content=content) - - # Mark as processed even if empty - self._processed_length = end_pos - return None - - def _flush_remaining_content( - self, current_text: str) -> Union[DeltaMessage, None]: - """Flush any remaining unprocessed content as regular content. - - Args: - current_text: Current accumulated text - - Used when EOS token is encountered to handle incomplete tool calls. - """ - if not self._has_unprocessed_content(current_text): - return None - - remaining = current_text[self._processed_length:] - if remaining: - self._processed_length = len(current_text) - return DeltaMessage(content=remaining) - - self._processed_length = len(current_text) - return None - - def _format_delta_result( - self, deltas: list[DeltaMessage] - ) -> Union[DeltaMessage, list[DeltaMessage], None]: - """Format delta result for return. 
- - Args: - deltas: List of delta messages - - Returns: - - None if empty - - Single DeltaMessage if only one - - List of DeltaMessage if multiple - """ - if not deltas: - return None - elif len(deltas) == 1: - return deltas[0] - else: - return deltas - - def _process_complete_tool_calls( - self, current_text: str, start_pos: int, - end_pos: int) -> Union[list[DeltaMessage], None]: - """Process complete tool calls and convert to delta sequence. - - Args: - current_text: Current accumulated text - start_pos: Start position (should be at ) - end_pos: End position (after ) - - Returns: - List of DeltaMessage if successful, None otherwise - """ - try: - # Extract text segment containing complete tool call(s) - text_to_parse = current_text[start_pos:end_pos] - - # Parse using non-streaming method - result = self.extract_tool_calls(text_to_parse, - self.streaming_request) - - # Case 1: Successfully parsed tool calls - if result.tools_called and result.tool_calls: - # Note: Due to _find_first_complete_tool_call_end, we typically - # process only one tool call at a time - # but we can also process multiple tool calls below - deltas = self._build_tool_call_deltas(result.tool_calls, - text_to_parse) - self._update_state_after_tool_calls(result.tool_calls, end_pos) - return deltas if deltas else None - - # Case 2: Parsing failed - treat as regular content - self._processed_length = end_pos - return [DeltaMessage(content=text_to_parse)] - - except Exception as e: - # Exception during parsing - treat as content - logger.debug( - f"Failed to parse tool calls: {e}, treating as content") - self._processed_length = end_pos - failed_text = current_text[start_pos:end_pos] - return [DeltaMessage(content=failed_text)] if failed_text else None - - def _build_tool_call_deltas(self, tool_calls: list[ToolCall], - parsed_text: str) -> list[DeltaMessage]: - """Build delta messages from parsed tool calls with interleaved content. 
- - Args: - tool_calls: List of parsed tool calls - parsed_text: Original text that was parsed - - Returns: - List of DeltaMessage with tool calls and content interleaved - """ - deltas = [] - - # Extract content segments between tool calls - content_segments = self._extract_content_between_tool_calls_list( - parsed_text) - - # Build deltas: tool_call[i] -> content[i] (if exists) - for i, tool_call in enumerate(tool_calls): - # Convert tool call to delta sequence - tool_deltas = self._convert_tool_calls_to_deltas( - [tool_call], self._tool_call_index + i) - deltas.extend(tool_deltas) - - # Add content after this tool call if exists - if i < len(content_segments) and content_segments[i]: - deltas.append(DeltaMessage(content=content_segments[i])) - - return deltas - - def _update_state_after_tool_calls(self, tool_calls: list[ToolCall], - end_pos: int) -> None: - """Update internal state after processing tool calls. - - Args: - tool_calls: List of processed tool calls - end_pos: End position in buffer - """ - # Update processed position - self._processed_length = end_pos - - # Update tool call index - self._tool_call_index += len(tool_calls) - - # Update prev_tool_call_arr for finish_reason - self.prev_tool_call_arr.clear() - for tool_call in tool_calls: - if tool_call: - self.prev_tool_call_arr.append({ - "name": - tool_call.function.name, - "arguments": - tool_call.function.arguments, - }) diff --git a/vllm/tool_parsers/step3p5_tool_parser.py b/vllm/tool_parsers/step3p5_tool_parser.py new file mode 100644 index 000000000000..32704079e965 --- /dev/null +++ b/vllm/tool_parsers/step3p5_tool_parser.py @@ -0,0 +1,1401 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import json +from collections.abc import Sequence +from typing import Any +from xml.parsers.expat import ParserCreate + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from 
vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) + +logger = init_logger(__name__) + +class StreamingXMLToolCallParser: + """ + Simplified streaming XML tool call parser + Supports streaming input, parsing, and output + """ + + def __init__(self): + self.reset_streaming_state() + + # Tool configuration information + self.tools: list[ChatCompletionToolsParam] | None = None + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.function_start_token: str = " DeltaMessage: + """ + Parse single streaming XML chunk and return Delta response + This is the actual streaming interface that receives chunks + one by one and maintains internal state + + Args: + xml_chunk: Single XML chunk string + Returns: + DeltaMessage: Contains delta information generated by this chunk, + returns empty response if no complete elements + """ + # Record delta count before processing + initial_delta_count = len(self.deltas) + + self.streaming_buffer += xml_chunk + + found_elements = self._process_complete_xml_elements() + + if found_elements: + # If complete elements found, check if end events were missed + # some tags may not have been triggered + try: + new_deltas = self.deltas[initial_delta_count:] + # If this chunk contains + # but didn't generate '}', then complete it + if (self.current_call_id is not None + and self.function_end_token in xml_chunk): + # - Added '}' (non-empty parameter ending) + # - Added '{}' (empty parameter function) + has_function_close = any((td.tool_calls and any( + (tc.function and tc.id == self.current_call_id + and 
isinstance(tc.function.arguments, str) and + (tc.function.arguments in ("}", "{}"))) + for tc in td.tool_calls)) for td in new_deltas) + if not has_function_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.current_function_name: + self._end_element("function") + # If this chunk contains + # but didn't generate final empty delta, then complete it + if (self.current_call_id is not None + and self.tool_call_end_token in xml_chunk): + has_toolcall_close = any((td.tool_calls and any( + (tc.type == "function" and tc.function and tc.function. + arguments == "" and tc.id == self.current_call_id) + for tc in td.tool_calls)) for td in new_deltas) + if not has_toolcall_close: + # Close potentially unclosed element + if self.current_param_name: + self._end_element("parameter") + if self.current_function_name: + self._end_element("function") + self._end_element("tool_call") + except Exception as e: + logger.warning("Error with fallback parsing: %s", e) + # Merge newly generated deltas into single response + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count) + return result_delta + else: + # No complete elements, check if there's unoutput text content + if self.text_content_buffer and self.tool_call_index == 0: + # Has text content but no tool_call yet, output text content + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + # Clear buffer to avoid duplicate output + self.text_content_buffer = "" + return text_delta + + # If this chunk contains end tags but wasn't triggered by parser, + # manually complete end events + # Only execute when still on the same call as when entered, + # to prevent accidentally closing new calls + # in multi scenarios + if self.current_call_id is not None and ( + self.function_end_token in xml_chunk + or self.tool_call_end_token in xml_chunk): + # Close potentially unclosed element + if self.current_param_name: 
+ self._end_element("parameter") + if self.function_end_token in xml_chunk and self.current_function_name: + self._end_element("function") + if self.tool_call_end_token in xml_chunk: + self._end_element("tool_call") + # Return the merged delta result generated by this fallback + result_delta = self._merge_new_deltas_to_single_response( + initial_delta_count) + return result_delta + + # No complete elements, return empty response + return DeltaMessage(content=None) + + def _escape_xml_special_chars(self, text: str) -> str: + """ + Escape XML special characters + Args: + text: Original text + Returns: + Escaped text + """ + xml_escapes = { + "&": "&", + "<": "<", + ">": ">", + '"': """, + "'": "'", + } + + for char, escape in xml_escapes.items(): + text = text.replace(char, escape) + + return text + + def _process_complete_xml_elements(self) -> bool: + """ + Process complete XML elements in buffer + + Returns: + bool: Whether complete elements were found and processed + """ + found_any = False + + while self.last_processed_pos < len(self.streaming_buffer): + # Find next complete xml element + element, end_pos = self._find_next_complete_element( + self.last_processed_pos) + if element is None: + # No complete element found, wait for more data + break + + # Check if this element should be skipped + if self._should_skip_element(element): + self.last_processed_pos = end_pos + continue + + # Found complete XML element, process it + try: + preprocessed_element = self._preprocess_xml_chunk(element) + # Check if this is the first tool_call start + if ((preprocessed_element.strip().startswith("") or + preprocessed_element.strip().startswith("") + and self.tool_call_index > 0 and self.current_call_id + and self.current_function_name): + # Reset parser state but preserve generated deltas + if self.current_param_name: + self._end_element("parameter") + if self.current_function_open: + self._end_element("function") + # Output final tool_call tail delta + final_delta = 
DeltaMessage( + role=None, + content=None, + reasoning_content=None, + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, + arguments=""), + ) + ], + ) + self._emit_delta(final_delta) + # Reset XML parser and current call state + self._reset_xml_parser_after_tool_call() + # Parse preprocessed element + self.parser.Parse(preprocessed_element, False) + found_any = True + + except Exception as e: + logger.warning("Error when parsing XML elements: %s", e) + + # Update processed position + self.last_processed_pos = end_pos + + return found_any + + def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str: + """ + Fallback: fix incomplete ) + Examples: , + Also handles missing = cases: -> , -> + Only fixes tags that pass validation (parameter exists in tool definition) + """ + # First, handle missing = cases for function tags + chunk = self._fix_missing_equals_in_function_tag(chunk) + + for tag_type in ["parameter", "function"]: + pattern = f"<{tag_type}=" + if pattern not in chunk: + continue + + start_idx = chunk.find(pattern) + after_tag = chunk[start_idx:] + gt_pos = after_tag.find(">") + lt_pos = after_tag.find("<", len(pattern)) + + # Skip if already well-formed + if (gt_pos != -1 and (lt_pos == -1 or gt_pos < lt_pos) + and pattern in after_tag[:gt_pos]): + continue + + # Extract tag name (stop at space, newline, or <) + content = chunk[start_idx + len(pattern):] + end_pos = next( + (i for i, ch in enumerate(content) if ch in (' ', '\n', '<')), + len(content)) + tag_name = content[:end_pos] + + if not tag_name: + continue + + # Remove duplicate prefix: ', 1) + + return chunk + + def _fix_missing_equals_in_function_tag(self, chunk: str) -> str: + """ + Fix missing = in function tags: or + Examples: + -> + -> + Only fixes if function name exists in tool definition + """ + # already correct + if ' (with space/newline but no =) + pattern1 = r'' + match1 = 
re.search(pattern1, chunk) + if match1: + func_name = match1.group(1).strip() + # must validate function name exists before fixing + if func_name and self._validate_function_name(func_name): + original = match1.group(0) + fixed = f'' + chunk = chunk.replace(original, fixed, 1) + return chunk + + # Pattern 2: (no space, no =) + # only match ' + match2 = re.search(pattern2, chunk) + if match2: + func_name = match2.group(1).strip() + # must validate function name exists before fixing + if func_name and self._validate_function_name(func_name): + original = match2.group(0) + fixed = f'' + chunk = chunk.replace(original, fixed, 1) + return chunk + + return chunk + + def _validate_function_name(self, func_name: str) -> bool: + """Check if function name exists in tool definitions""" + if not self.tools: + return False + + for tool in self.tools: + if (hasattr(tool, "type") and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + and tool.function.name == func_name): + return True + + return False + + def _validate_parameter_name(self, param_name: str) -> bool: + """Check if parameter exists in current function's tool definition""" + if not self.tools or not self.current_function_name: + return True + + for tool in self.tools: + if (hasattr(tool, "type") and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + and tool.function.name == self.current_function_name): + if not hasattr(tool.function, "parameters"): + return True + params = tool.function.parameters + if isinstance(params, dict): + properties = params.get("properties", params) + return param_name in properties + break + + return True + + def _should_skip_element(self, element: str) -> bool: + """ + Determine whether an element should be skipped + + Args: + element: Element to evaluate + + Returns: + bool: True means should skip, False means should process + """ + + # If it's a tool_call XML tag, don't skip + if 
(element.startswith(self.tool_call_start_token) + or element.startswith(self.function_start_token) + or element.startswith(self.parameter_start_token)): + return False + + # If currently not parsing tool calls and not blank, + # collect this text instead of skipping + # Only process other XML elements after tool_call appears, + # otherwise treat as plain text + if self.current_call_id is None and element: + # Collect text content to buffer + self.text_content_buffer += element + return True # Still skip, but content has been collected + + # If currently parsing tool calls, + # this might be parameter value, don't skip + if self.current_call_id is not None: + return False + + # Skip blank content + return not element + + def _find_next_complete_element(self, + start_pos: int) -> tuple[str | None, int]: + """ + Find next complete XML element from specified position + + Args: + start_pos: Position to start searching + + Returns: + (Complete element string, element end position), + returns (None, start_pos) if no complete element found + """ + buffer = self.streaming_buffer[start_pos:] + + if not buffer: + return None, start_pos + + if buffer.startswith("<"): + # Check if this is an incomplete parameter/function tag + # e.g., " not in buffer.split("\n")[0]) + is_incomplete_func = (buffer.startswith("" not in buffer.split("\n")[0]) + + if is_incomplete_param or is_incomplete_func: + # Find the corresponding closing tag + tag_type = "parameter" if is_incomplete_param else "function" + closing_tag = f"" + closing_pos = buffer.find(closing_tag) + + if closing_pos != -1: + # Found closing tag, return complete element including closing tag + complete_element = buffer[:closing_pos + len(closing_tag)] + return complete_element, start_pos + closing_pos + len( + closing_tag) + + # Need to ensure no new < appears, + # find the nearest one between < and > + tag_end = buffer.find("<", 1) + tag_end2 = buffer.find(">", 1) + if tag_end != -1 and tag_end2 != -1: + # Next nearest is < + 
if tag_end < tag_end2: + return buffer[:tag_end], start_pos + tag_end + # Next nearest is >, means found XML element + else: + return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1 + elif tag_end != -1: + return buffer[:tag_end], start_pos + tag_end + elif tag_end2 != -1: + return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1 + else: + # If currently not parsing tool calls (entering a tool_call), + # check if starts with or + if buffer == ""[:len(buffer)]: + # Might be start of , wait for more data + return None, start_pos + elif (buffer.startswith(" DeltaMessage: + """ + Merge newly generated deltas from this processing + into a single DeltaMessage + + Args: + initial_count: Delta count before processing + + Returns: + Merged DeltaMessage containing all newly generated delta information + """ + if len(self.deltas) <= initial_count: + return DeltaMessage(content=None) + + # Get newly generated deltas + new_deltas = self.deltas[initial_count:] + + if len(new_deltas) == 1: + # Only one new delta, return directly + return new_deltas[0] + + # Merge multiple new deltas + merged_tool_calls: list[DeltaToolCall] = [] + merged_content: str = "" + + for delta in new_deltas: + if delta.content: + merged_content += delta.content + if delta.tool_calls: + # For tool_calls, we need to intelligently merge arguments + for tool_call in delta.tool_calls: + # Find if there's already a tool_call with the same call_id + existing_call = None + for existing in merged_tool_calls: + if existing.id == tool_call.id: + existing_call = existing + break + + if existing_call and existing_call.function: + # Merge to existing tool_call + if tool_call.function and tool_call.function.name: + existing_call.function.name = tool_call.function.name + if (tool_call.function + and tool_call.function.arguments is not None): + if existing_call.function.arguments is None: + existing_call.function.arguments = "" + + # For streaming JSON parameters, + # simply concatenate in order + new_args = 
tool_call.function.arguments + existing_call.function.arguments += new_args + if tool_call.type: + existing_call.type = tool_call.type + else: + # Add new tool_call + merged_tool_calls.append(tool_call) + + return DeltaMessage( + content=merged_content if merged_content else None, + tool_calls=merged_tool_calls, + ) + + def _preprocess_xml_chunk(self, chunk: str) -> str: + """ + Preprocess XML chunk, handle non-standard formats, + and escape special characters + + Args: + chunk: Original XML chunk + + Returns: + Processed XML chunk + """ + + # Check if this is a tool_call related element + is_tool_call = False + if chunk.startswith(self.tool_call_start_token) or chunk.startswith( + self.tool_call_end_token): + is_tool_call = True + # Check for function tags (including malformed ones without =) + # , , , + if (chunk.startswith(self.function_start_token) + or chunk.startswith(self.function_end_token) + or chunk.startswith(" + # This handles cases like: format -> + processed = re.sub(r"]+)>", r'', + chunk) + # Handle format -> + processed = re.sub(r"]+)>", r'', + processed) + + original_chunk = chunk + # If in parameter value accumulation mode + if self._pre_inside_parameter: + # Parameter end: output accumulated raw text + # safely then return + if processed.startswith(""): + body_text = self._pre_param_buffer + # Trigger deferred parsing mode + # literal_eval+json output in end_element + self.defer_current_parameter = True + self.deferred_param_raw_value = body_text + # Clean up state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + safe_text = self._escape_xml_special_chars(body_text) + return f"{safe_text}" + else: + # If this is the first block of content after entering parameter + # evaluate if deferred parsing is needed; + # If not needed, exit accumulation mode + # and pass through directly + if self._pre_param_buffer == "": + # Get current parameter type + param_type = (self._get_param_type( + 
self._pre_current_param_name) if + self._pre_current_param_name else "string") + # Only these types need deferred parsing to + # handle Python literals containing single quotes + is_object_type = param_type in ["object"] + is_complex_type = (param_type + in ["array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list")) + + # Only delay when contains container symbols + # and has single quotes and is complex type + has_container_hint = (("[" in original_chunk) + or ("{" in original_chunk) + or ("(" in original_chunk)) + + # Determine if deferred parsing is needed + need_defer = False + if is_complex_type: + # Complex type, always need deferred parsing + need_defer = True + elif (is_object_type and has_container_hint + and ("'" in original_chunk)): + # Object type with container symbols + # and single quotes, need deferred parsing + need_defer = True + + if not need_defer: + # No need for deferred parsing, + # exit parameter mode directly + self._pre_inside_parameter = False + return self._escape_xml_special_chars(original_chunk) + self._pre_param_buffer += original_chunk + return "" + + # Parameter start: enable accumulation + if processed.startswith("', processed) + if m: + self._pre_current_param_name = m.group(1) + self._pre_inside_parameter = True + self._pre_param_buffer = "" + return processed + + # If processed doesn't contain special_token, escape processed + # This is because XML parsing encounters special characters + # and reports errors, so escaping is needed + if not is_tool_call: + processed = self._escape_xml_special_chars(processed) + return processed + + def _emit_delta(self, delta: DeltaMessage): + """Emit Delta response (streaming output)""" + self.deltas.append(delta) + + def _auto_close_open_parameter_if_needed(self, + incoming_tag: str | None = None): + """Before starting to process new elements, + if there are unclosed tags from before, + automatically complete their endings to the parser. 
+ - If there are unclosed parameters, + it's equivalent to feeding `` + - When about to start a new function or tool_call, + if there are unclosed functions, complete ``. + - When about to start a new tool_call, + if there are unclosed tool_calls, complete ``. + """ + # First close unclosed parameters + if self.current_param_name: + self._end_element("parameter") + + # If about to start new function or tool_call, + # and there are unclosed functions, close function first + if incoming_tag in ("function", + "tool_call") and self.current_function_name: + self._end_element("function") + + # If about to start new tool_call, + # and there are unclosed tool_calls, close tool_call first + if incoming_tag == "tool_call" and self.current_call_id: + self._end_element("tool_call") + + def _start_element(self, name: str, attrs: dict[str, str]): + """Handle XML start element events""" + + if name == "root": + return + + if name == "tool_call": + # Before opening new tool_call, + # automatically complete previous unclosed tags + self._auto_close_open_parameter_if_needed("tool_call") + + self.parameters = {} + self.current_call_id = make_tool_call_id() + self.current_param_is_first = True + self.tool_call_index += 1 + elif name.startswith("function") or (name == "function"): + # If missing tool_call, manually complete + if not self.current_call_id: + self._start_element("tool_call", {}) + # Before opening new function, + # automatically complete previous unclosed tags (parameter/function) + self._auto_close_open_parameter_if_needed("function") + function_name = self._extract_function_name(name, attrs) + self.current_function_name = function_name + self.current_function_open = True + if function_name: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=function_name, + arguments=""), + ) + ]) + self._emit_delta(delta) + elif name.startswith("parameter") or (name == 
"parameter"): + # If previous parameter hasn't ended normally, + # complete its end first, then start new parameter + self._auto_close_open_parameter_if_needed("parameter") + param_name = self._extract_parameter_name(name, attrs) + self.current_param_name = param_name + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False # Reset start quote flag + + # Only output parameter name and colon, + # don't output quotes + # decide after parameter value type is determined + if param_name: + if not self.parameters: + # First parameter + # start JSON, only output parameter name and colon + json_start = f'{{"{param_name}": ' + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, + arguments=json_start), + ) + ]) + self._emit_delta(delta) + self.current_param_is_first = True + else: + # Subsequent parameters + # add comma and parameter name, no quotes + json_continue = f', "{param_name}": ' + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_continue), + ) + ]) + self._emit_delta(delta) + self.current_param_is_first = False + + def _char_data(self, data: str): + """Handle XML character data events""" + if data and self.current_param_name: + # If preprocessing stage determines deferred parsing is needed, + # only cache character data, no streaming output + if self.defer_current_parameter: + original_data = data + if self.should_emit_end_newline: + original_data = "\n" + original_data + self.should_emit_end_newline = False + if original_data.endswith("\n"): + self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + return + + param_type = self._get_param_type(self.current_param_name) + + # Check if this is the 
first time receiving data for this parameter + # If this is the first packet of data and starts with \n, remove \n + if not self.current_param_value and data.startswith("\n"): + data = data[1:] + + # Output start quote for string type (if not already output) + if (param_type + in ["string", "str", "text", "varchar", "char", "enum"] + and not self.start_quote_emitted): + quote_delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ]) + self._emit_delta(quote_delta) + self.start_quote_emitted = True + + if not data: + return + + original_data = data + # Delay output of trailing newline + if self.should_emit_end_newline: + original_data = "\n" + original_data + self.should_emit_end_newline = False + if original_data.endswith("\n"): + self.should_emit_end_newline = True + original_data = original_data[:-1] + self.current_param_value += original_data + + # convert parameter value by param_type + converted_value = self._convert_param_value( + self.current_param_value, param_type) + output_data = self._convert_for_json_streaming( + converted_value, param_type) + + delta_data = output_data[len(self.current_param_value_converted):] + self.current_param_value_converted = output_data + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, + arguments=delta_data), + ) + ]) + self._emit_delta(delta) + + def _end_element(self, name: str): + """Handle XML end element events""" + + if name == "root": + return + + # If function or tool_call ends and there are still unclosed parameters, + # complete parameter end first + if (name.startswith("function") or name == "function" + or name == "tool_call") and self.current_param_name: + self._auto_close_open_parameter_if_needed() + + if (name.startswith("parameter") + or name == 
"parameter") and self.current_param_name: + # End current parameter + param_name = self.current_param_name + param_value = self.current_param_value + + # If in deferred parsing mode, + # perform overall parsing on raw content + # accumulated in preprocessing stage and output once + if self.defer_current_parameter: + raw_text = (self.deferred_param_raw_value + if self.deferred_param_raw_value else param_value) + parsed_value = None + output_arguments = None + try: + # If previously delayed trailing newline, + # add it back before parsing + if self.should_emit_end_newline: + raw_for_parse = raw_text + "\n" + else: + raw_for_parse = raw_text + parsed_value = ast.literal_eval(raw_for_parse) + output_arguments = json.dumps(parsed_value, + ensure_ascii=False) + except Exception: + # Fallback: output as string as-is + output_arguments = json.dumps(raw_text, ensure_ascii=False) + parsed_value = raw_text + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, + arguments=output_arguments), + ) + ]) + self._emit_delta(delta) + + # Clean up and store + self.should_emit_end_newline = False + self.parameters[param_name] = parsed_value + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + return + + param_type = self._get_param_type(param_name) + + # convert complete parameter value by param_type + converted_value = self._convert_param_value( + param_value, param_type) + + # Decide whether to add end quote based on parameter type + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + # For empty string parameters, need special handling + if not param_value and not self.start_quote_emitted: + # No start quote output, + # directly output complete empty string + delta = 
DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, + arguments='""'), + ) + ]) + self._emit_delta(delta) + else: + # Non-empty parameter value, output end quote + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, + arguments='"'), + ) + ]) + self._emit_delta(delta) + + self.should_emit_end_newline = False + # Store converted value + self.parameters[param_name] = converted_value + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.start_quote_emitted = False + + elif name.startswith("function") or name == "function": + # if there are parameters, close JSON object + if self.parameters: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="}"), + ) + ]) + self._emit_delta(delta) + # return empty object + else: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="{}"), + ) + ]) + self._emit_delta(delta) + self.current_function_open = False + self.current_function_name = None # Clear function name to prevent duplicate closing + + elif name == "tool_call": + # Before ending tool_call, + # ensure function is closed to complete missing right brace + if self.current_function_open: + # If there are still unclosed parameters, close them first + if self.current_param_name: + self._end_element("parameter") + # Close function, ensure output '}' or '{}' + self._end_element("function") + # Final Delta + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + 
function=DeltaFunctionCall(name=None, arguments=""), + ) + ]) + self._emit_delta(delta) + + # Check if there's text content to output (between tool_calls) + if self.text_content_buffer.strip(): + text_delta = DeltaMessage(content=self.text_content_buffer) + self._emit_delta(text_delta) + + self._reset_xml_parser_after_tool_call() + + def setup_parser(self): + """Set up XML parser event handlers""" + self.parser.buffer_text = True + self.parser.StartElementHandler = self._start_element + self.parser.EndElementHandler = self._end_element + self.parser.CharacterDataHandler = self._char_data + + def set_tools(self, tools: list[ChatCompletionToolsParam] | None): + """Set tool configuration information""" + self.tools = tools + + def _extract_function_name(self, name: str, + attrs: dict[str, str]) -> str | None: + """Extract function name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "function": + return parts[1] + + return None + + def _extract_parameter_name(self, name: str, + attrs: dict[str, str]) -> str | None: + """Extract parameter name from various formats""" + if attrs and "name" in attrs: + return attrs["name"] + + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "parameter": + return parts[1] + + return None + + def _get_param_type(self, param_name: str) -> str: + """Get parameter type based on tool configuration, defaults to string + Args: + param_name: Parameter name + + Returns: + Parameter type + """ + if not self.tools or not self.current_function_name: + return "string" + + for tool in self.tools: + if not hasattr(tool, "type") or not (hasattr( + tool, "function") and hasattr(tool.function, "name")): + continue + if (tool.type == "function" + and tool.function.name == self.current_function_name): + if not hasattr(tool.function, "parameters"): + return "string" + params = tool.function.parameters + if 
isinstance(params, dict) and "properties" in params: + properties = params["properties"] + if param_name in properties and isinstance( + properties[param_name], dict): + return self.repair_param_type( + str(properties[param_name].get("type", "string"))) + elif isinstance(params, dict) and param_name in params: + param_config = params[param_name] + if isinstance(param_config, dict): + return self.repair_param_type( + str(param_config.get("type", "string"))) + break + return "string" + + def repair_param_type(self, param_type: str) -> str: + """Repair unknown parameter types by treating them as string + Args: + param_type: Parameter type + + Returns: + Repaired parameter type + """ + if (param_type in ["string", "str", "text", "varchar", "char", "enum"] + or param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + or param_type.startswith("num") + or param_type.startswith("float") + or param_type in ["boolean", "bool", "binary"] + or (param_type in ["object", "array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list"))): + return param_type + else: + return "string" + + def _convert_param_value(self, param_value: str, param_type: str) -> Any: + """Convert value based on parameter type + Args: + param_value: Parameter value + param_type: Parameter type + + Returns: + Converted value + """ + if param_value.lower() == "null": + return None + + param_type = param_type.strip().lower() + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' is not an integer, " + "degenerating to string.", + 
param_value, + ) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value: float = float(param_value) + return (float_param_value if float_param_value - + int(float_param_value) != 0 else + int(float_param_value)) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' is not a float, " + "degenerating to string.", + param_value, + ) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + return param_value == "true" + else: + return param_value + + def _convert_for_json_streaming(self, converted_value: Any, + param_type: str) -> str: + """Convert converted_value based on + whether it's empty and if type is string + Args: + converted_value: Converted value + param_type: Parameter type + + Returns: + Converted string for streaming output + """ + # Check if value is empty, but exclude numeric 0 + if converted_value is None or converted_value == "": + return "" + + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + # String type, remove double quotes + return json.dumps(converted_value, ensure_ascii=False)[1:-1] + else: + # Non-string type, return complete JSON string + if not isinstance(converted_value, str): + return json.dumps(converted_value, ensure_ascii=False) + else: + return converted_value + + def _reset_xml_parser_after_tool_call(self): + """ + Each tool_call is treated as a separate XML document, + so we need to reset the parser after each tool_call. 
+ """ + + # recreate XML parser + self.parser = ParserCreate() + self.setup_parser() + + # Reset current tool_call state + if self.current_call_id: + self.last_completed_call_id = self.current_call_id + self.current_call_id = None + self.current_function_name = None + self.current_function_open = False + self.parameters = {} + self.current_param_name = None + self.current_param_value = "" + self.current_param_value_converted = "" + self.current_param_is_first = False + self.should_emit_end_newline = False + self.start_quote_emitted = False + self.text_content_buffer = "" + + # Reset preprocessing and deferred parsing state + self._pre_inside_parameter = False + self._pre_param_buffer = "" + self._pre_current_param_name = None + self.defer_current_parameter = False + self.deferred_param_raw_value = "" + + +@ToolParserManager.register_module("step3p5") +class Step3p5ToolParser(ToolParser): + + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + self.parser = StreamingXMLToolCallParser() + + # Add missing attributes for compatibility with serving_chat.py + self.prev_tool_call_arr: list[dict] = [] + self.streamed_args_for_tool: list[str] = [] + + logger.info("vLLM Successfully import tool parser %s !", + self.__class__.__name__) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + self.parser.reset_streaming_state() + # Reset tool call tracking arrays for new extraction + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] + if request: + self.parser.set_tools(request.tools) + result = self.parser.parse_single_streaming_chunks(model_output) + if not result.tool_calls: + return ExtractedToolCallInformation( + tool_calls=[], + tools_called=False, + content=result.content, + ) + else: + tool_calls = [] + for tool_call in result.tool_calls: + if tool_call.function and tool_call.function.name: + tool_calls.append( + ToolCall( + id=tool_call.id, + 
type=tool_call.type, + function=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + )) + + # Update tool call tracking arrays for compatibility + tool_index = (tool_call.index + if tool_call.index is not None else + len(self.prev_tool_call_arr) - 1) + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({ + "name": "", + "arguments": "" + }) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + # Update tool call information + self.prev_tool_call_arr[tool_index]["name"] = ( + tool_call.function.name) + self.prev_tool_call_arr[tool_index]["arguments"] = ( + tool_call.function.arguments) + + # Update streamed arguments + if tool_call.function.arguments: + self.streamed_args_for_tool[tool_index] = ( + tool_call.function.arguments) + + return ExtractedToolCallInformation( + tool_calls=tool_calls, + tools_called=len(tool_calls) > 0, + content=result.content, + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + if not previous_text: + self.parser.reset_streaming_state() + # Reset tool call tracking arrays for new streaming session + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] + if request: + self.parser.set_tools(request.tools) + + # Model sometimes outputs separately causing delta_text to be empty. 
+ # If there were tool_calls before and all current tool_calls have ended, + # return an empty tool_call for outer streaming output + # to correctly output tool_call field + if not delta_text and delta_token_ids: + open_calls = current_text.count( + self.parser.tool_call_start_token) - current_text.count( + self.parser.tool_call_end_token) + if (open_calls == 0 and self.parser.tool_call_index > 0 + or not self.parser.tool_call_index and current_text): + return DeltaMessage(content="") + return None + + # Parse the delta text and get the result + result = self.parser.parse_single_streaming_chunks(delta_text) + + # Update tool call tracking arrays based on incremental parsing results + if result and result.tool_calls: + for tool_call in result.tool_calls: + if tool_call.function: + tool_index = (tool_call.index + if tool_call.index is not None else + len(self.prev_tool_call_arr) - 1) + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({ + "name": "", + "arguments": "" + }) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + # Update tool name if provided + if tool_call.function.name: + self.prev_tool_call_arr[tool_index]["name"] = ( + tool_call.function.name) + + # Update arguments incrementally + if tool_call.function.arguments is not None: + # Concatenate the incremental arguments + # to the existing streamed arguments + self.prev_tool_call_arr[tool_index]["arguments"] += ( + tool_call.function.arguments) + self.streamed_args_for_tool[tool_index] += ( + tool_call.function.arguments) + return result + + def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool: + """ + Skip the remaining_call calculation in serving_chat + """ + return False From 3268d784ac025fbd2b1ffc85add01d3de5085396 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 28 Jan 2026 21:19:09 +0800 Subject: [PATCH 06/34] fix: attn module import error --- 
vllm/model_executor/models/step3p5.py | 2 +- vllm/reasoning/step3p5_reasoning_parser.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 31a95d360a8b..9f5c8d1433cd 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -8,7 +8,6 @@ from torch import nn import vllm.envs as envs -from vllm.attention.layer import Attention from vllm.v1.attention.backend import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig @@ -19,6 +18,7 @@ get_tp_group) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py index 93aa7f5ee08d..c1c5a5123e43 100644 --- a/vllm/reasoning/step3p5_reasoning_parser.py +++ b/vllm/reasoning/step3p5_reasoning_parser.py @@ -122,6 +122,7 @@ def extract_reasoning_streaming( # Content: handle the newline immediately after . if content_to_output is not None: + self.end_offset -= 1 # If we have content, reasoning must have ended. 
self._pending_reasoning_newline = False From b020ee213dd4abee98f59fa921343c36af709c63 Mon Sep 17 00:00:00 2001 From: i-zhangmingming Date: Wed, 28 Jan 2026 14:16:51 +0000 Subject: [PATCH 07/34] feat: support mtp3 --- examples/offline_inference/spec_decode.py | 15 +- tests/v1/spec_decode/test_mtp3.py | 661 ++++++++++++++++++++++ vllm/config/speculative.py | 21 +- vllm/model_executor/models/step3p5.py | 1 + vllm/model_executor/models/step3p5_mtp.py | 164 +++--- vllm/v1/core/kv_cache_utils.py | 19 +- vllm/v1/spec_decode/eagle.py | 11 +- vllm/v1/spec_decode/multi_layer_eagle.py | 475 ++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 18 +- 9 files changed, 1294 insertions(+), 91 deletions(-) create mode 100644 tests/v1/spec_decode/test_mtp3.py create mode 100644 vllm/v1/spec_decode/multi_layer_eagle.py diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 1e3e310e7fa5..d890c1e6c673 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -62,6 +62,11 @@ def parse_args(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument( + "--enable-multi-layers-mtp", + action="store_true", + help="Enable multi-layer MTP (only effective when --method=mtp).", + ) parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) @@ -71,6 +76,7 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--draft-model", type=str, default=None) + parser.add_argument("--tokenizer-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, 
default=0.9) parser.add_argument("--disable-padded-drafter-batch", action="store_true") @@ -90,7 +96,12 @@ def main(args): "please specify model_dir to give a mm based model" ) model_dir = "meta-llama/Llama-3.1-8B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_dir) + + tokenizer_dir = args.tokenizer_dir + if tokenizer_dir is None: + tokenizer_dir = model_dir + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) args.custom_skip_chat_template = True if not args.custom_mm_prompts: @@ -138,6 +149,8 @@ def main(args): "method": "mtp", "num_speculative_tokens": args.num_spec_tokens, } + if args.enable_multi_layers_mtp: + speculative_config["enable_multi_layers_mtp"] = True else: raise ValueError(f"unknown method: {args.method}") diff --git a/tests/v1/spec_decode/test_mtp3.py b/tests/v1/spec_decode/test_mtp3.py new file mode 100644 index 000000000000..5eff1f6a3075 --- /dev/null +++ b/tests/v1/spec_decode/test_mtp3.py @@ -0,0 +1,661 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.v1.attention.backend import AttentionMetadataBuilder +from vllm.v1.spec_decode.multi_layer_eagle import ( + DraftInputStates, + MultiLayerEagleProposer, +) + + +class DummyBuilder(AttentionMetadataBuilder): + def __init__(self, return_value: str): + # attention metadata builders normally take multiple runtime args; + # the test double shortcuts that setup. 
+ self.return_value = return_value + self.calls: list[dict] = [] + self.kv_cache_spec = None + self.layer_names: list[str] = [] + self.vllm_config = None + self.device = torch.device("cpu") + + def build( + self, common_prefix_len: int, common_attn_metadata, fast_build: bool = False + ): + self.calls.append( + { + "common_prefix_len": common_prefix_len, + "common_attn_metadata": common_attn_metadata, + "fast_build": fast_build, + } + ) + return self.return_value + + +@pytest.fixture +def proposer_stub(): + proposer = MultiLayerEagleProposer.__new__(MultiLayerEagleProposer) + proposer.layer_num = 3 + proposer.running_req_ids = ["req-0"] + proposer.attn_layer_names = ["attn_layer"] + proposer.indexer_layer_names = ["indexer_layer"] + proposer.attn_metadata_builder = DummyBuilder("attn_meta") + proposer.draft_indexer_metadata_builder = DummyBuilder("indexer_meta") + proposer.draft_input_states_pool = { + "req-0": DraftInputStates( + len=3, + token_ids=torch.tensor([800, 801, 802], dtype=torch.int32), + hidden_states=torch.tensor( + [[30.0, 31.0, 32.0], [33.0, 34.0, 35.0], [36.0, 37.0, 38.0]] + ), + positions=torch.tensor([0, 1, 2], dtype=torch.int64), + slot_mapping=torch.tensor([900, 901, 902], dtype=torch.int32), + ), + "req-1": DraftInputStates( + len=2, + token_ids=torch.tensor([910, 911], dtype=torch.int32), + hidden_states=torch.tensor([[40.0, 41.0, 42.0], [43.0, 44.0, 45.0]]), + positions=torch.tensor([0, 1], dtype=torch.int64), + slot_mapping=torch.tensor([990, 991], dtype=torch.int32), + ), + "req-2": DraftInputStates( + len=3, + token_ids=torch.tensor([820, 821, 822], dtype=torch.int32), + hidden_states=torch.tensor( + [[46.0, 47.0, 48.0], [49.0, 50.0, 51.0], [52.0, 53.0, 54.0]] + ), + positions=torch.tensor([0, 1, 2], dtype=torch.int64), + slot_mapping=torch.tensor([920, 921, 922], dtype=torch.int32), + ), + "req-3": DraftInputStates( + len=3, + token_ids=torch.tensor([830, 831, 832], dtype=torch.int32), + hidden_states=torch.tensor( + [[55.0, 56.0, 57.0], 
[58.0, 59.0, 60.0], [61.0, 62.0, 63.0]] + ), + positions=torch.tensor([0, 1, 2], dtype=torch.int64), + slot_mapping=torch.tensor([930, 931, 932], dtype=torch.int32), + ), + } + return proposer + + +LAYER3_CASES = [ + { + "name": "layer3_shift0_sequence_end", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [10, 11, 12, 13], + "target_positions": [0, 1, 2, 3], + "last_token_indices": [3], + "common_attn_metadata": { + "query_start_loc": [0, 4], + "query_start_loc_cpu": [0, 4], + "seq_lens": [4], + "seq_lens_cpu": [4], + "num_computed_tokens_cpu": [0], + "slot_mapping": [100, 101, 102, 103], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13], + "prev_positions": [0, 1, 2, 3], + "last_token_indices": [3], + "seq_lens": [4], + "seq_lens_cpu": [4], + "num_computed_tokens_cpu": [0], + "slot_mapping": [100, 101, 102, 103], + "max_seq_len": 4, + "cached_by_req": { + "req-0": { + "len": 3, + "token_ids": [11, 12, 13], + "positions": [1, 2, 3], + }, + }, + }, + }, + { + "name": "layer3_batch2_short_seq_no_shift", + "batch_size": 2, + "running_req_ids": ["req-0", "req-1"], + "target_token_ids": [10, 11, 20], + "target_positions": [0, 1, 0], + "last_token_indices": [1, 2], + "common_attn_metadata": { + "query_start_loc": [0, 2, 3], + "query_start_loc_cpu": [0, 2, 3], + "seq_lens": [2, 1], + "seq_lens_cpu": [2, 1], + "num_computed_tokens_cpu": [0, 0], + "slot_mapping": [100, 101, 200], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [10, 11, 20], + "prev_positions": [0, 1, 0], + "last_token_indices": [1, 2], + "seq_lens": [2, 1], + "seq_lens_cpu": [2, 1], + "num_computed_tokens_cpu": [0, 0], + "slot_mapping": [100, 101, 200], + "max_seq_len": 2, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [10, 11], + "positions": [0, 1], + }, + "req-1": { + "len": 1, + "token_ids": [20], + "positions": [0], + }, + }, + }, + }, + { + "name": "layer3_batch2_short_seq_shift_first", + "batch_size": 2, + 
"running_req_ids": ["req-0", "req-1"], + "target_token_ids": [10, 11, 20], + "target_positions": [1, 2, 0], + "last_token_indices": [0, 2], + "common_attn_metadata": { + "query_start_loc": [0, 2, 3], + "query_start_loc_cpu": [0, 2, 3], + "seq_lens": [2, 1], + "seq_lens_cpu": [2, 1], + "num_computed_tokens_cpu": [1, 0], + "slot_mapping": [100, 101, 200], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [802, 10, 20], + "prev_positions": [2, 1, 0], + "last_token_indices": [1, 2], + "seq_lens": [1, 1], + "seq_lens_cpu": [1, 1], + "num_computed_tokens_cpu": [0, 0], + "slot_mapping": [902, 100, 200], + "max_seq_len": 1, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [802, 10], + "positions": [2, 1], + }, + "req-1": { + "len": 1, + "token_ids": [20], + "positions": [0], + }, + }, + }, + }, + { + "name": "layer3_short_seq_len2_shift0_cache1", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [7, 8], + "target_positions": [0, 1], + "last_token_indices": [0], + "common_attn_metadata": { + "query_start_loc": [0, 2], + "query_start_loc_cpu": [0, 2], + "seq_lens": [2], + "seq_lens_cpu": [2], + "num_computed_tokens_cpu": [0], + "slot_mapping": [1000, 1001], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [7, 8], + "prev_positions": [0, 1], + "last_token_indices": [0], + "seq_lens": [2], + "seq_lens_cpu": [2], + "num_computed_tokens_cpu": [0], + "slot_mapping": [1000, 1001], + "max_seq_len": 2, + "cached_by_req": { + "req-0": { + "len": 1, + "token_ids": [7], + "positions": [0], + }, + }, + }, + }, + { + "name": "layer3_short_seq_len2_shift1_cache2", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [7, 8], + "target_positions": [1, 2], + "last_token_indices": [0], + "common_attn_metadata": { + "query_start_loc": [0, 2], + "query_start_loc_cpu": [0, 2], + "seq_lens": [2], + "seq_lens_cpu": [2], + "num_computed_tokens_cpu": [1], + "slot_mapping": [1000, 1001], + "max_seq_len": 2, + }, + 
"expected": { + "prev_token_ids": [802, 7], + "prev_positions": [2, 1], + "last_token_indices": [1], + "seq_lens": [1], + "seq_lens_cpu": [1], + "num_computed_tokens_cpu": [0], + "slot_mapping": [902, 1000], + "max_seq_len": 1, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [802, 7], + "positions": [2, 1], + }, + }, + }, + }, + { + "name": "layer3_shift_bounded_start_pos0", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [10, 11, 12, 13], + "target_positions": [0, 1, 2, 3], + "last_token_indices": [1], + "common_attn_metadata": { + "query_start_loc": [0, 4], + "query_start_loc_cpu": [0, 4], + "seq_lens": [4], + "seq_lens_cpu": [4], + "num_computed_tokens_cpu": [0], + "slot_mapping": [100, 101, 102, 103], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13], + "prev_positions": [0, 1, 2, 3], + "last_token_indices": [1], + "seq_lens": [4], + "seq_lens_cpu": [4], + "num_computed_tokens_cpu": [0], + "slot_mapping": [100, 101, 102, 103], + "max_seq_len": 4, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [10, 11], + "positions": [0, 1], + }, + }, + }, + }, + { + "name": "layer3_shift_bounded_start_pos", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [10, 11, 12, 13, 14], + "target_positions": [0, 1, 2, 3, 4], + "last_token_indices": [1], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "query_start_loc_cpu": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [1], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13, 14], + "prev_positions": [0, 1, 2, 3, 4], + "last_token_indices": [1], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [1], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [10, 11], + "positions": [0, 1], + }, + }, + }, + }, + { + "name": 
"layer3_shift2_bounded_remaining", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [10, 11, 12, 13, 14], + "target_positions": [0, 1, 2, 3, 4], + "last_token_indices": [2], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "query_start_loc_cpu": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [2], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13, 14], + "prev_positions": [0, 1, 2, 3, 4], + "last_token_indices": [2], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [2], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + "cached_by_req": { + "req-0": { + "len": 3, + "token_ids": [10, 11, 12], + "positions": [0, 1, 2], + }, + }, + }, + }, + { + "name": "layer3_shift3_full_cache_window", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [20, 21, 22, 23, 24], + "target_positions": [0, 1, 2, 3, 4], + "last_token_indices": [1], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "query_start_loc_cpu": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [3], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [20, 21, 22, 23, 24], + "prev_positions": [0, 1, 2, 3, 4], + "last_token_indices": [1], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [3], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [20, 21], + "positions": [0, 1], + }, + }, + }, + }, + { + "name": "layer3_batch2_shift1_and1", + "batch_size": 2, + "running_req_ids": ["req-0", "req-1"], + "target_token_ids": [10, 11, 12, 13, 20, 21, 22], + "target_positions": [0, 1, 2, 3, 0, 1, 2], + "last_token_indices": [1, 5], + "common_attn_metadata": { + "query_start_loc": [0, 4, 7], + "query_start_loc_cpu": [0, 4, 7], + "seq_lens": [4, 3], + 
"seq_lens_cpu": [4, 3], + "num_computed_tokens_cpu": [1, 1], + "slot_mapping": [100, 101, 102, 103, 200, 201, 202], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13, 20, 21, 22], + "prev_positions": [0, 1, 2, 3, 0, 1, 2], + "last_token_indices": [1, 5], + "seq_lens": [4, 3], + "seq_lens_cpu": [4, 3], + "num_computed_tokens_cpu": [1, 1], + "slot_mapping": [100, 101, 102, 103, 200, 201, 202], + "max_seq_len": 4, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [10, 11], + "positions": [0, 1], + }, + "req-1": { + "len": 2, + "token_ids": [20, 21], + "positions": [0, 1], + }, + }, + }, + }, + { + "name": "layer3_batch4_mixed_shifts", + "batch_size": 4, + "running_req_ids": ["req-0", "req-1", "req-2", "req-3"], + "target_token_ids": [10, 11, 20, 21, 22, 30, 31, 32, 33, 40, 41, 42], + "target_positions": [0, 1, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2], + "last_token_indices": [1, 2, 6, 10], + "common_attn_metadata": { + "query_start_loc": [0, 2, 5, 9, 12], + "query_start_loc_cpu": [0, 2, 5, 9, 12], + "seq_lens": [2, 3, 4, 3], + "seq_lens_cpu": [2, 3, 4, 3], + "num_computed_tokens_cpu": [0, 1, 2, 1], + "slot_mapping": [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + ], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 911, 20, 21, 30, 31, 32, 33, 40, 41, 42], + "prev_positions": [0, 1, 1, 1, 2, 0, 1, 2, 3, 0, 1, 2], + "last_token_indices": [1, 3, 6, 10], + "seq_lens": [2, 2, 4, 3], + "seq_lens_cpu": [2, 2, 4, 3], + "num_computed_tokens_cpu": [0, 0, 2, 1], + "slot_mapping": [ + 100, + 101, + 991, + 102, + 103, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + ], + "max_seq_len": 4, + "cached_by_req": { + "req-0": { + "len": 2, + "token_ids": [10, 11], + "positions": [0, 1], + }, + "req-1": { + "len": 2, + "token_ids": [911, 20], + "positions": [1, 1], + }, + "req-2": { + "len": 2, + "token_ids": [30, 31], + "positions": [0, 1], + }, + "req-3": { + "len": 2, + "token_ids": [40, 41], + 
"positions": [0, 1], + }, + }, + }, + }, + { + "name": "layer3_batch2_shift0_and2", + "batch_size": 2, + "running_req_ids": ["req-0", "req-1"], + "target_token_ids": [30, 31, 32, 40, 41, 42, 43], + "target_positions": [0, 1, 2, 0, 1, 2, 3], + "last_token_indices": [2, 4], + "common_attn_metadata": { + "query_start_loc": [0, 3, 7], + "query_start_loc_cpu": [0, 3, 7], + "seq_lens": [3, 4], + "seq_lens_cpu": [3, 4], + "num_computed_tokens_cpu": [0, 2], + "slot_mapping": [100, 101, 102, 200, 201, 202, 203], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [30, 31, 32, 40, 41, 42, 43], + "prev_positions": [0, 1, 2, 0, 1, 2, 3], + "last_token_indices": [2, 4], + "seq_lens": [3, 4], + "seq_lens_cpu": [3, 4], + "num_computed_tokens_cpu": [0, 2], + "slot_mapping": [100, 101, 102, 200, 201, 202, 203], + "max_seq_len": 4, + "cached_by_req": { + "req-0": { + "len": 3, + "token_ids": [30, 31, 32], + "positions": [0, 1, 2], + }, + "req-1": { + "len": 2, + "token_ids": [40, 41], + "positions": [0, 1], + }, + }, + }, + }, +] + +LAYER5_CASES = [ + { + "name": "layer5_cache_window5", + "batch_size": 1, + "running_req_ids": ["req-0"], + "target_token_ids": [1, 2, 3, 4, 5, 6], + "target_positions": [0, 1, 2, 3, 4, 5], + "last_token_indices": [2], + "common_attn_metadata": { + "query_start_loc": [0, 6], + "query_start_loc_cpu": [0, 6], + "seq_lens": [6], + "seq_lens_cpu": [6], + "num_computed_tokens_cpu": [2], + "slot_mapping": [100, 101, 102, 103, 104, 105], + "max_seq_len": 6, + }, + "expected": { + "prev_token_ids": [1, 2, 3, 4, 5, 6], + "prev_positions": [0, 1, 2, 3, 4, 5], + "last_token_indices": [2], + "seq_lens": [6], + "seq_lens_cpu": [6], + "num_computed_tokens_cpu": [2], + "slot_mapping": [100, 101, 102, 103, 104, 105], + "max_seq_len": 6, + "cached_by_req": { + "req-0": { + "len": 3, + "token_ids": [1, 2, 3], + "positions": [0, 1, 2], + }, + }, + }, + }, +] + + +def _run_adjust_input_case(proposer_stub, case, layer_num): + proposer_stub.layer_num = layer_num + 
proposer_stub.running_req_ids = case["running_req_ids"] + meta = case["common_attn_metadata"] + common_attn_metadata = SimpleNamespace( + query_start_loc=torch.tensor(meta["query_start_loc"], dtype=torch.int32), + query_start_loc_cpu=torch.tensor(meta["query_start_loc"], dtype=torch.int32), + seq_lens=torch.tensor(meta["seq_lens"], dtype=torch.int32), + seq_lens_cpu=torch.tensor(meta["seq_lens_cpu"], dtype=torch.int32), + num_computed_tokens_cpu=torch.tensor( + meta["num_computed_tokens_cpu"], dtype=torch.int32 + ), + slot_mapping=torch.tensor(meta["slot_mapping"], dtype=torch.int32), + max_seq_len=meta["max_seq_len"], + ) + + target_token_ids = torch.tensor(case["target_token_ids"], dtype=torch.int32) + target_positions = torch.tensor(case["target_positions"], dtype=torch.int64) + target_hidden_states = torch.arange( + 0, target_token_ids.numel() * 3, dtype=torch.float32 + ).reshape(-1, 3) + last_token_indices = torch.tensor(case["last_token_indices"], dtype=torch.int32) + + prev_token_ids, prev_positions, _, _, _ = proposer_stub.adjust_input( + batch_size=case["batch_size"], + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + last_token_indices=last_token_indices, + common_attn_metadata=common_attn_metadata, + ) + + expected = case["expected"] + assert prev_token_ids.tolist() == expected["prev_token_ids"] + assert prev_positions.tolist() == expected["prev_positions"] + assert last_token_indices.tolist() == expected["last_token_indices"] + assert common_attn_metadata.seq_lens.tolist() == expected["seq_lens"] + assert common_attn_metadata.seq_lens_cpu.tolist() == expected["seq_lens_cpu"] + assert ( + common_attn_metadata.num_computed_tokens_cpu.tolist() + == expected["num_computed_tokens_cpu"] + ) + assert common_attn_metadata.slot_mapping.tolist() == expected["slot_mapping"] + assert common_attn_metadata.max_seq_len == expected["max_seq_len"] + + for req_id, cached_expect in 
expected["cached_by_req"].items(): + cached = proposer_stub.draft_input_states_pool[req_id] + assert cached.len == cached_expect["len"] + assert cached.token_ids.tolist() == cached_expect["token_ids"] + assert cached.positions.tolist() == cached_expect["positions"] + + +@pytest.mark.parametrize( + "case", LAYER3_CASES, ids=[case["name"] for case in LAYER3_CASES] +) +def test_adjust_input_layer3_cases(proposer_stub, case): + _run_adjust_input_case(proposer_stub, case, layer_num=3) + + +@pytest.mark.parametrize( + "case", LAYER5_CASES, ids=[case["name"] for case in LAYER5_CASES] +) +def test_adjust_input_layer5_cases(proposer_stub, case): + _run_adjust_input_case(proposer_stub, case, layer_num=5) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c7648487866c..5db0ddd778f0 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -41,7 +41,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", - "step3p5_mtp" + "step3p5_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] SpeculativeMethod = Literal[ @@ -76,6 +76,12 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" + + enable_multi_layers_mtp: bool = False + """If set to True, the MTP method will run multiple layers of MTP + speculator. If set to False, it will run only one layer of MTP speculator. + This is only effective when the method is set to `mtp`.""" + draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. 
Can only be 1 or the same as the target model's tensor parallel size.""" @@ -264,14 +270,12 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) - + if hf_config.model_type == "step3p5": hf_config.model_type = "step3p5_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) - hf_config.update( - {"n_predict": n_predict, "architectures": ["Step3p5MTP"]} - ) - + hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]}) + if initial_architecture == "MistralLarge3ForCausalLM": hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) @@ -404,7 +408,10 @@ def __post_init__(self): MTPModelTypes ): self.method = "mtp" - if self.num_speculative_tokens > 1: + if ( + self.enable_multi_layers_mtp is False + and self.num_speculative_tokens > 1 + ): logger.warning( "Enabling num_speculative_tokens > 1 will run" "multiple times of forward on same MTP layer" diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 9f5c8d1433cd..a07d5ad18929 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -8,6 +8,7 @@ from torch import nn import vllm.envs as envs +from vllm.model_executor.layers.attention import Attention from vllm.v1.attention.backend import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py index a7747d09e9d5..996a717e54ba 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -1,18 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as 
nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -23,11 +24,10 @@ class SharedHead(nn.Module): - def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) @@ -40,17 +40,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Step3p5AMultiTokenPredictorLayer(nn.Module): - def __init__( self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - parallel_config: ParallelConfig = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + *, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) @@ -69,25 +70,22 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, 
spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None - # masking inputs at position 0, as not needed by MTP - inputs_embeds[positions == 0] = 0 inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states = self.mtp_block(positions=positions, - hidden_states=hidden_states) + hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states) return hidden_states class Step3p5AMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -98,19 +96,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - Step3p5AMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - parallel_config=vllm_config.parallel_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): Step3p5AMultiTokenPredictorLayer( + vllm_config=vllm_config, + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.logits_processor = LogitsProcessor(config.vocab_size) @@ -119,12 +116,12 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> 
torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -138,11 +135,11 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers[str(self.mtp_start_layer_idx + - current_step_idx)] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states)) + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) return logits def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -150,14 +147,13 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: class Step3p5MTP(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config self.vllm_config = vllm_config - self.model = Step3p5AMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = Step3p5AMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -167,23 +163,23 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, 
positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: vllm_config = self.vllm_config config = vllm_config.model_config.hf_config @@ -211,7 +207,7 @@ def load_weights(self, weights: Iterable[tuple[str, if "embed_tokens" not in name and spec_layer is None: continue name = self._rewrite_spec_layer_name(spec_layer, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -221,7 +217,7 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue if "experts" in name or "moe" in name: continue @@ -241,40 +237,52 @@ def load_weights(self, weights: Iterable[tuple[str, continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader for expert_id in range(loaded_weight.shape[0]): loaded_weight_expert = loaded_weight[expert_id] - weight_loader(param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(name) break else: # Skip loading extra bias for GPTQ models. - if name.endswith( - ".bias" - ) and name not in params_dict or "tok_embeddings" in name: + if ( + name.endswith(".bias") + and name not in params_dict + or "tok_embeddings" in name + ): continue - if f"{config.num_hidden_layers}.transformer." in name: - name = name.replace(".transformer.", ".") + mtp_start_layer_idx = config.num_hidden_layers + num_mtp_layers = config.num_nextn_predict_layers + + for idx in range( + mtp_start_layer_idx, mtp_start_layer_idx + num_mtp_layers + ): + if f"{idx}.transformer." 
in name: + name = name.replace(".transformer.", ".") if "shared_head" in name: - name = name.replace("shared_head.output", - "shared_head.head") + name = name.replace("shared_head.output", "shared_head.head") if "embed_tokens" in name: - assert hasattr( - self.config, "num_nextn_predict_layers" - ) and self.config.num_nextn_predict_layers > 0 + assert ( + hasattr(self.config, "num_nextn_predict_layers") + and self.config.num_nextn_predict_layers > 0 + ) name = "model.embed_tokens.weight" param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) params_need_to_load = set(params_dict.keys()) @@ -292,7 +300,8 @@ def load_weights(self, weights: Iterable[tuple[str, missing_params = list(params_need_to_load - loaded_params) param_name_example = missing_params[0] raise RuntimeError( - f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization" + f"Some parameters like {param_name_example} are not in the checkpoint" + f" and will falsely use random initialization" ) return loaded_params @@ -302,7 +311,11 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: Add .mtp_block for modules in transformer layer block for spec layer """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] spec_layer_weight = False for weight_name in spec_layer_weight_names: @@ -311,6 +324,7 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", 
f"model.layers.{spec_layer}.mtp_block." + ) return name diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index fd12dfe045a4..605c91012d3f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -947,6 +947,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo def _get_kv_cache_groups_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: """ @@ -1061,8 +1062,18 @@ def _get_kv_cache_groups_uniform_page_size( # the same and will cause memory waste. # To avoid this, we assign layers[i::num_groups] to the i-th group # instead of layers[i * group_size: (i + 1) * group_size] - for i in range(num_groups): - grouped_layers.append(layers[i::num_groups]) + + # for support multi layer mtp, we need to + # make all mtp layers in the same group + if ( + vllm_config.speculative_config is not None + and vllm_config.speculative_config.enable_multi_layers_mtp + ): + for i in range(0, len(layers), group_size): + grouped_layers.append(layers[i : i + group_size]) + else: + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) @@ -1247,7 +1258,9 @@ def get_kv_cache_groups( # have the same physical memory per block per layer. Split the layers # into groups with the same number of layers, and thus same total page # size. 
- return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) + return _get_kv_cache_groups_uniform_page_size( + vllm_config=vllm_config, kv_cache_spec=kv_cache_spec + ) def generate_scheduler_kv_cache_config( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 43b84f4be8a2..0e50e0eb9f72 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -381,8 +381,11 @@ def propose( input_ids = None inputs_embeds = self.inputs_embeds[:num_input_tokens] else: + self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( + self.input_ids[:num_tokens], + ) input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] model_kwargs = { "input_ids": input_ids, @@ -575,12 +578,12 @@ def propose( self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) - input_ids = None inputs_embeds = self.inputs_embeds[:input_batch_size] else: + self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) input_ids = self.input_ids[:input_batch_size] - inputs_embeds = None + inputs_embeds = self.inputs_embeds[:input_batch_size] # Run the model. 
model_kwargs = { @@ -1326,7 +1329,7 @@ def dummy_run( inputs_embeds = self.inputs_embeds[:num_input_tokens] else: input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] kwargs = dict( input_ids=input_ids, diff --git a/vllm/v1/spec_decode/multi_layer_eagle.py b/vllm/v1/spec_decode/multi_layer_eagle.py new file mode 100644 index 000000000000..9731d8bcbe9c --- /dev/null +++ b/vllm/v1/spec_decode/multi_layer_eagle.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch + +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.v1.attention.backend import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.eagle import EagleProposer + +logger = init_logger(__name__) + +PADDING_SLOT_ID = -1 + + +class DraftInputStates: + def __init__( + self, + len: int, + token_ids: torch.Tensor, + hidden_states: torch.Tensor, + positions: torch.Tensor, + slot_mapping: torch.Tensor, + ): + self.len = len + self.token_ids = token_ids + self.hidden_states = hidden_states + self.positions = positions + self.slot_mapping = slot_mapping + + +class MultiLayerEagleProposer(EagleProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config, device, runner) + + self.layer_num: int = getattr( + self.speculative_config.draft_model_config.hf_text_config, "n_predict", 0 + ) + self.num_speculative_tokens: int = ( + self.speculative_config.num_speculative_tokens + ) + if self.num_speculative_tokens != self.layer_num: + logger.warning_once( + "For multi_layer_eagle, num_speculative_tokens " + 
"does not match layer_num, adjusting to layer_num" + ) + self.num_speculative_tokens = self.layer_num + self.running_req_ids: list[str] | None = None + self.draft_input_states_pool: dict[str, DraftInputStates] = {} + + def set_running_req_ids(self, req_ids: list[str]): + self.running_req_ids = req_ids + + def _get_draft_input_states(self, req_id: str, len: int) -> DraftInputStates: + draft_input_states = self.draft_input_states_pool.get(req_id, None) + assert draft_input_states is not None + assert draft_input_states.len >= len + return draft_input_states + + def clean_req_cache(self, req_id: str): + self.draft_input_states_pool.pop(req_id, None) + + def adjust_input( + self, + batch_size: int, + target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + last_token_indices: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any, dict[str, Any]]: + start_token_indices = common_attn_metadata.query_start_loc[:-1] + start_token_pos = target_positions[start_token_indices] + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + query_start_loc_cpu_np = query_start_loc_cpu.numpy() + start_token_indices_cpu = query_start_loc_cpu_np[:-1] + end_token_indices_cpu = query_start_loc_cpu_np[1:] - 1 + last_token_indices_cpu = last_token_indices.cpu().numpy() + start_token_pos_cpu = start_token_pos.cpu().numpy() + + prev_token_ids = target_token_ids + prev_positions = target_positions + prev_hidden_states = target_hidden_states + + for i in range(batch_size): + last_token_index: int = int(last_token_indices_cpu[i]) + start_token_index: int = int(start_token_indices_cpu[i]) + end_token_index: int = int(end_token_indices_cpu[i]) + start_pos: int = int(start_token_pos_cpu[i]) + assert self.running_req_ids is not None + req_id = self.running_req_ids[i] + shift = min(end_token_index - last_token_index, start_pos) + + modify_last_token_index = last_token_index + 
if shift > 0: + + def shift_input( + input: torch.Tensor, + cached: torch.Tensor, + start_token_index: int = start_token_index, + end_token_index: int = end_token_index, + shift: int = shift, + ) -> torch.Tensor: + window_len = end_token_index - start_token_index + 1 + dest = input.narrow( + 0, start_token_index + shift, window_len - shift + ) + # clone is used to ensure correctness in the case of + # overlap between src and dest + src = input.narrow(0, start_token_index, window_len - shift).clone() + dest.copy_(src) + head = input.narrow(0, start_token_index, shift) + head.copy_(cached[-shift:]) + return input + + cached_input_state = self._get_draft_input_states(req_id, shift) + prev_token_ids = shift_input( + prev_token_ids, cached_input_state.token_ids + ) + prev_positions = shift_input( + prev_positions, cached_input_state.positions + ) + prev_hidden_states = shift_input( + prev_hidden_states, cached_input_state.hidden_states + ) + common_attn_metadata.slot_mapping = shift_input( + common_attn_metadata.slot_mapping, cached_input_state.slot_mapping + ) + common_attn_metadata.seq_lens[i] -= shift + common_attn_metadata.num_computed_tokens_cpu[i] -= shift + common_attn_metadata.seq_lens_cpu[i] -= shift + + modify_last_token_index = last_token_index + shift + last_token_indices[i] += shift + + cache_start_index = max( + start_token_index, modify_last_token_index + 1 - self.layer_num + ) + + self.draft_input_states_pool[req_id] = DraftInputStates( + len=modify_last_token_index + 1 - cache_start_index, + token_ids=prev_token_ids[ + cache_start_index : modify_last_token_index + 1 + ].clone(), + hidden_states=prev_hidden_states[ + cache_start_index : modify_last_token_index + 1 + ].clone(), + positions=prev_positions[ + cache_start_index : modify_last_token_index + 1 + ].clone(), + slot_mapping=common_attn_metadata.slot_mapping[ + cache_start_index : modify_last_token_index + 1 + ].clone(), + ) + + common_attn_metadata.max_seq_len = torch.max( + 
common_attn_metadata.seq_lens + ).item() + + if self.attn_metadata_builder is None: + attn_metadata_builder = self._get_attention_metadata_builder() + else: + attn_metadata_builder = self.attn_metadata_builder + + assert isinstance(attn_metadata_builder, AttentionMetadataBuilder) + + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + + # FIXME: support hybrid kv for draft model (remove separate indexer) + if self.draft_indexer_metadata_builder: + draft_indexer_metadata = ( + self.draft_indexer_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + ) + else: + draft_indexer_metadata = None + + # At this moment, we assume all eagle layers belong to the same KV + # cache group, thus using the same attention metadata. + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + for layer_name in self.indexer_layer_names: + assert draft_indexer_metadata is not None + per_layer_attn_metadata[layer_name] = draft_indexer_metadata + + return ( + prev_token_ids, + prev_positions, + prev_hidden_states, + attn_metadata, + per_layer_attn_metadata, + ) + + def initial_inputs_for_forward( + self, + num_tokens: int, + prev_token_ids: torch.Tensor, + prev_positions: torch.Tensor, + prev_hidden_states: torch.Tensor, + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor, + spec_step_idx: int = 0, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + ): + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[: num_tokens - 1] = prev_token_ids[1:] + # Replace the last token with the next token. 
+ # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + self._set_positions(num_tokens, prev_positions) + self.hidden_states[:num_tokens] = prev_hidden_states[:num_tokens] + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + + self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) + else: + self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( + self.input_ids[:num_tokens], + ) + + def draft_model_forward( + self, + num_tokens: int, + per_layer_attn_metadata: dict[str, Any], + last_token_indices: torch.Tensor, + sampling_metadata: SamplingMetadata, + common_attn_metadata: CommonAttentionMetadata, + spec_step_idx: int = 0, + ): + num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( + num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens + ) + + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens + + if num_tokens_across_dp is not None: + num_tokens_across_dp[self.dp_rank] = num_input_tokens + + if self.supports_mm_inputs: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = self.inputs_embeds[:num_input_tokens] + + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(num_input_tokens), + "hidden_states": self.hidden_states[:num_input_tokens], + "inputs_embeds": inputs_embeds, + "spec_step_idx": spec_step_idx, + } + + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=self._get_slot_mapping( + num_input_tokens, common_attn_metadata.slot_mapping + ), + ): + 
last_hidden_states = self.model(**model_kwargs) + + sample_hidden_states = last_hidden_states[last_token_indices] + logits = self.model.compute_logits( + sample_hidden_states, spec_step_idx=spec_step_idx + ) + + draft_token_ids = logits.argmax(dim=-1) + + return draft_token_ids, last_hidden_states + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, + common_attn_metadata: CommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + num_rejected_tokens_gpu: torch.Tensor | None = None, + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, + ) -> torch.Tensor: + assert self.method == "mtp" + assert self.runner is not None + assert target_positions.dim() == 1, ( + "MultiLayerEagleProposer does not support M-RoPE yet; " + f"got target_positions with shape {tuple(target_positions.shape)}" + ) + + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + ( + prev_token_ids, + prev_positions, + prev_hidden_states, + attn_metadata, + per_layer_attn_metadata, + ) = self.adjust_input( + batch_size=batch_size, + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + last_token_indices=last_token_indices, + common_attn_metadata=common_attn_metadata, + ) + + if isinstance(attn_metadata, TreeAttentionMetadata): + raise NotImplementedError( + "Tree attention is not supported for multi layer eagle." 
+ ) + + if self.allowed_attn_types is not None and not isinstance( + attn_metadata, self.allowed_attn_types + ): + raise ValueError( + f"Unsupported attention metadata type for speculative " + "decoding for multi layer eagle: " + f"{type(attn_metadata)}. Supported types are: " + f"{self.allowed_attn_types}" + ) + + # Generate the remaining draft tokens. + draft_token_ids_list: list[torch.Tensor] = [] + + for token_index in range(self.num_speculative_tokens): + if token_index != 0: + prev_token_ids = self.input_ids[:num_tokens].clone() + next_token_ids = draft_token_ids_list[-1].int() + + self.initial_inputs_for_forward( + num_tokens=num_tokens, + prev_token_ids=prev_token_ids, + prev_positions=prev_positions, + prev_hidden_states=prev_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=last_token_indices, + spec_step_idx=token_index, + mm_embed_inputs=mm_embed_inputs, + ) + + draft_token_ids, prev_hidden_states = self.draft_model_forward( + num_tokens=num_tokens, + per_layer_attn_metadata=per_layer_attn_metadata, + last_token_indices=last_token_indices, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + spec_step_idx=token_index, + ) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + return draft_token_ids.view(-1, 1) + + draft_token_ids_list.append(draft_token_ids) + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + + return draft_token_ids + + def prepare_inputs( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). 
It also returns the token indices + of the tokens that should be fed to the speculator. + """ + raise Exception( + "speculative_config.disable_padded_drafter_batch" + " is not supported now for MultiLayerEagleProposer." + ) + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + use_cudagraphs: bool = True, + is_graph_capturing: bool = False, + slot_mappings: dict[str, torch.Tensor] | None = None, + ) -> None: + num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( + num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens + ) + if use_cudagraphs: + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens + else: + cudagraph_runtime_mode = CUDAGraphMode.NONE + num_input_tokens = num_tokens_dp_padded + if num_tokens_across_dp is not None: + num_tokens_across_dp[self.dp_rank] = num_input_tokens + + # Make sure to use EAGLE's own buffer during cudagraph capture. + if ( + self.attn_layer_names + and slot_mappings is not None + and self.attn_layer_names[0] in slot_mappings + ): + slot_mapping_dict = self._get_slot_mapping(num_input_tokens) + else: + slot_mapping_dict = slot_mappings or {} + + for fwd_idx in range(self.layer_num): + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=slot_mapping_dict, + ): + if self.supports_mm_inputs: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = self.inputs_embeds[:num_input_tokens] + + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(num_input_tokens), + "hidden_states": self.hidden_states[:num_input_tokens], + "inputs_embeds": inputs_embeds, + "spec_step_idx": fwd_idx, + } + + self.model(**model_kwargs) diff --git 
a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 061ac8680157..f01604073dfa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -155,6 +155,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -463,7 +464,15 @@ def __init__( elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, self) + if ( + self.speculative_config.enable_multi_layers_mtp + and self.speculative_config.method == "mtp" + ): + self.drafter = MultiLayerEagleProposer( + self.vllm_config, self.device, self + ) + else: + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = ( self.drafter.eagle3_use_aux_hidden_state @@ -884,6 +893,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) + if hasattr(self, "drafter") and isinstance( + self.drafter, MultiLayerEagleProposer + ): + self.drafter.clean_req_cache(req_id) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. 
This happens when a request is aborted and @@ -4078,6 +4091,9 @@ def propose_draft_token_ids( else: mm_embed_inputs = None + if isinstance(self.drafter, MultiLayerEagleProposer): + self.drafter.set_running_req_ids(self.input_batch.req_ids) + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, From b247624a2e7647e8f355289902c4007bf155c46d Mon Sep 17 00:00:00 2001 From: csy0225 Date: Wed, 28 Jan 2026 22:38:04 +0800 Subject: [PATCH 08/34] fix: FlashInferExperts.apply() got an unexpected keyword argument 'activation_limit' --- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 7c27da46fee5..cff0822207c6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -210,6 +210,7 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, + activation_limit: float | None = None ): from flashinfer.fused_moe.core import ActivationType From 45eeaf8163a2ee0d100e7300e534630ca8d6d802 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 10:44:07 +0800 Subject: [PATCH 09/34] fix: step3p5 reasoning parser error --- vllm/model_executor/layers/activation.py | 13 +++++-------- vllm/reasoning/step3p5_reasoning_parser.py | 1 - 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index b41278fb9fbc..2a623ede3d39 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -34,14 +34,11 @@ def swiglustep_and_mul_out( """ # Prefer the fused custom op when available (CUDA); fallback to PyTorch ops # otherwise. 
- if x.is_cuda and hasattr(torch.ops._C, "swiglustep_and_mul"): - torch.ops._C.swiglustep_and_mul(out, x, limit) - else: - gate, up = x.chunk(2, dim=-1) - gate = F.silu(gate) - gate = gate.clamp(max=limit) - up = up.clamp(min=-limit, max=limit) - out.copy_(gate * up) + gate, up = x.chunk(2, dim=-1) + gate = F.silu(gate) + gate = gate.clamp(max=limit) + up = up.clamp(min=-limit, max=limit) + out.copy_(gate * up) return out diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py index c1c5a5123e43..93aa7f5ee08d 100644 --- a/vllm/reasoning/step3p5_reasoning_parser.py +++ b/vllm/reasoning/step3p5_reasoning_parser.py @@ -122,7 +122,6 @@ def extract_reasoning_streaming( # Content: handle the newline immediately after . if content_to_output is not None: - self.end_offset -= 1 # If we have content, reasoning must have ended. self._pending_reasoning_newline = False From f911ab0e4fcb3e76f63097d669108d57362ded5a Mon Sep 17 00:00:00 2001 From: xiewuxun Date: Thu, 29 Jan 2026 04:01:52 +0000 Subject: [PATCH 10/34] Revert "feat: support mtp3" --- examples/offline_inference/spec_decode.py | 15 +- tests/v1/spec_decode/test_mtp3.py | 661 ---------------------- vllm/config/speculative.py | 21 +- vllm/model_executor/models/step3p5.py | 1 - vllm/model_executor/models/step3p5_mtp.py | 164 +++--- vllm/v1/core/kv_cache_utils.py | 19 +- vllm/v1/spec_decode/eagle.py | 11 +- vllm/v1/spec_decode/multi_layer_eagle.py | 475 ---------------- vllm/v1/worker/gpu_model_runner.py | 18 +- 9 files changed, 91 insertions(+), 1294 deletions(-) delete mode 100644 tests/v1/spec_decode/test_mtp3.py delete mode 100644 vllm/v1/spec_decode/multi_layer_eagle.py diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index d890c1e6c673..1e3e310e7fa5 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -62,11 +62,6 @@ def parse_args(): parser.add_argument("--tp", type=int, 
default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") - parser.add_argument( - "--enable-multi-layers-mtp", - action="store_true", - help="Enable multi-layer MTP (only effective when --method=mtp).", - ) parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) @@ -76,7 +71,6 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--draft-model", type=str, default=None) - parser.add_argument("--tokenizer-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--disable-padded-drafter-batch", action="store_true") @@ -96,12 +90,7 @@ def main(args): "please specify model_dir to give a mm based model" ) model_dir = "meta-llama/Llama-3.1-8B-Instruct" - - tokenizer_dir = args.tokenizer_dir - if tokenizer_dir is None: - tokenizer_dir = model_dir - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) + tokenizer = AutoTokenizer.from_pretrained(model_dir) args.custom_skip_chat_template = True if not args.custom_mm_prompts: @@ -149,8 +138,6 @@ def main(args): "method": "mtp", "num_speculative_tokens": args.num_spec_tokens, } - if args.enable_multi_layers_mtp: - speculative_config["enable_multi_layers_mtp"] = True else: raise ValueError(f"unknown method: {args.method}") diff --git a/tests/v1/spec_decode/test_mtp3.py b/tests/v1/spec_decode/test_mtp3.py deleted file mode 100644 index 5eff1f6a3075..000000000000 --- a/tests/v1/spec_decode/test_mtp3.py +++ /dev/null @@ -1,661 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from types import SimpleNamespace - -import pytest 
-import torch - -from vllm.v1.attention.backend import AttentionMetadataBuilder -from vllm.v1.spec_decode.multi_layer_eagle import ( - DraftInputStates, - MultiLayerEagleProposer, -) - - -class DummyBuilder(AttentionMetadataBuilder): - def __init__(self, return_value: str): - # attention metadata builders normally take multiple runtime args; - # the test double shortcuts that setup. - self.return_value = return_value - self.calls: list[dict] = [] - self.kv_cache_spec = None - self.layer_names: list[str] = [] - self.vllm_config = None - self.device = torch.device("cpu") - - def build( - self, common_prefix_len: int, common_attn_metadata, fast_build: bool = False - ): - self.calls.append( - { - "common_prefix_len": common_prefix_len, - "common_attn_metadata": common_attn_metadata, - "fast_build": fast_build, - } - ) - return self.return_value - - -@pytest.fixture -def proposer_stub(): - proposer = MultiLayerEagleProposer.__new__(MultiLayerEagleProposer) - proposer.layer_num = 3 - proposer.running_req_ids = ["req-0"] - proposer.attn_layer_names = ["attn_layer"] - proposer.indexer_layer_names = ["indexer_layer"] - proposer.attn_metadata_builder = DummyBuilder("attn_meta") - proposer.draft_indexer_metadata_builder = DummyBuilder("indexer_meta") - proposer.draft_input_states_pool = { - "req-0": DraftInputStates( - len=3, - token_ids=torch.tensor([800, 801, 802], dtype=torch.int32), - hidden_states=torch.tensor( - [[30.0, 31.0, 32.0], [33.0, 34.0, 35.0], [36.0, 37.0, 38.0]] - ), - positions=torch.tensor([0, 1, 2], dtype=torch.int64), - slot_mapping=torch.tensor([900, 901, 902], dtype=torch.int32), - ), - "req-1": DraftInputStates( - len=2, - token_ids=torch.tensor([910, 911], dtype=torch.int32), - hidden_states=torch.tensor([[40.0, 41.0, 42.0], [43.0, 44.0, 45.0]]), - positions=torch.tensor([0, 1], dtype=torch.int64), - slot_mapping=torch.tensor([990, 991], dtype=torch.int32), - ), - "req-2": DraftInputStates( - len=3, - token_ids=torch.tensor([820, 821, 822], 
dtype=torch.int32), - hidden_states=torch.tensor( - [[46.0, 47.0, 48.0], [49.0, 50.0, 51.0], [52.0, 53.0, 54.0]] - ), - positions=torch.tensor([0, 1, 2], dtype=torch.int64), - slot_mapping=torch.tensor([920, 921, 922], dtype=torch.int32), - ), - "req-3": DraftInputStates( - len=3, - token_ids=torch.tensor([830, 831, 832], dtype=torch.int32), - hidden_states=torch.tensor( - [[55.0, 56.0, 57.0], [58.0, 59.0, 60.0], [61.0, 62.0, 63.0]] - ), - positions=torch.tensor([0, 1, 2], dtype=torch.int64), - slot_mapping=torch.tensor([930, 931, 932], dtype=torch.int32), - ), - } - return proposer - - -LAYER3_CASES = [ - { - "name": "layer3_shift0_sequence_end", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [10, 11, 12, 13], - "target_positions": [0, 1, 2, 3], - "last_token_indices": [3], - "common_attn_metadata": { - "query_start_loc": [0, 4], - "query_start_loc_cpu": [0, 4], - "seq_lens": [4], - "seq_lens_cpu": [4], - "num_computed_tokens_cpu": [0], - "slot_mapping": [100, 101, 102, 103], - "max_seq_len": 4, - }, - "expected": { - "prev_token_ids": [10, 11, 12, 13], - "prev_positions": [0, 1, 2, 3], - "last_token_indices": [3], - "seq_lens": [4], - "seq_lens_cpu": [4], - "num_computed_tokens_cpu": [0], - "slot_mapping": [100, 101, 102, 103], - "max_seq_len": 4, - "cached_by_req": { - "req-0": { - "len": 3, - "token_ids": [11, 12, 13], - "positions": [1, 2, 3], - }, - }, - }, - }, - { - "name": "layer3_batch2_short_seq_no_shift", - "batch_size": 2, - "running_req_ids": ["req-0", "req-1"], - "target_token_ids": [10, 11, 20], - "target_positions": [0, 1, 0], - "last_token_indices": [1, 2], - "common_attn_metadata": { - "query_start_loc": [0, 2, 3], - "query_start_loc_cpu": [0, 2, 3], - "seq_lens": [2, 1], - "seq_lens_cpu": [2, 1], - "num_computed_tokens_cpu": [0, 0], - "slot_mapping": [100, 101, 200], - "max_seq_len": 2, - }, - "expected": { - "prev_token_ids": [10, 11, 20], - "prev_positions": [0, 1, 0], - "last_token_indices": [1, 2], - "seq_lens": [2, 
1], - "seq_lens_cpu": [2, 1], - "num_computed_tokens_cpu": [0, 0], - "slot_mapping": [100, 101, 200], - "max_seq_len": 2, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [10, 11], - "positions": [0, 1], - }, - "req-1": { - "len": 1, - "token_ids": [20], - "positions": [0], - }, - }, - }, - }, - { - "name": "layer3_batch2_short_seq_shift_first", - "batch_size": 2, - "running_req_ids": ["req-0", "req-1"], - "target_token_ids": [10, 11, 20], - "target_positions": [1, 2, 0], - "last_token_indices": [0, 2], - "common_attn_metadata": { - "query_start_loc": [0, 2, 3], - "query_start_loc_cpu": [0, 2, 3], - "seq_lens": [2, 1], - "seq_lens_cpu": [2, 1], - "num_computed_tokens_cpu": [1, 0], - "slot_mapping": [100, 101, 200], - "max_seq_len": 2, - }, - "expected": { - "prev_token_ids": [802, 10, 20], - "prev_positions": [2, 1, 0], - "last_token_indices": [1, 2], - "seq_lens": [1, 1], - "seq_lens_cpu": [1, 1], - "num_computed_tokens_cpu": [0, 0], - "slot_mapping": [902, 100, 200], - "max_seq_len": 1, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [802, 10], - "positions": [2, 1], - }, - "req-1": { - "len": 1, - "token_ids": [20], - "positions": [0], - }, - }, - }, - }, - { - "name": "layer3_short_seq_len2_shift0_cache1", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [7, 8], - "target_positions": [0, 1], - "last_token_indices": [0], - "common_attn_metadata": { - "query_start_loc": [0, 2], - "query_start_loc_cpu": [0, 2], - "seq_lens": [2], - "seq_lens_cpu": [2], - "num_computed_tokens_cpu": [0], - "slot_mapping": [1000, 1001], - "max_seq_len": 2, - }, - "expected": { - "prev_token_ids": [7, 8], - "prev_positions": [0, 1], - "last_token_indices": [0], - "seq_lens": [2], - "seq_lens_cpu": [2], - "num_computed_tokens_cpu": [0], - "slot_mapping": [1000, 1001], - "max_seq_len": 2, - "cached_by_req": { - "req-0": { - "len": 1, - "token_ids": [7], - "positions": [0], - }, - }, - }, - }, - { - "name": 
"layer3_short_seq_len2_shift1_cache2", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [7, 8], - "target_positions": [1, 2], - "last_token_indices": [0], - "common_attn_metadata": { - "query_start_loc": [0, 2], - "query_start_loc_cpu": [0, 2], - "seq_lens": [2], - "seq_lens_cpu": [2], - "num_computed_tokens_cpu": [1], - "slot_mapping": [1000, 1001], - "max_seq_len": 2, - }, - "expected": { - "prev_token_ids": [802, 7], - "prev_positions": [2, 1], - "last_token_indices": [1], - "seq_lens": [1], - "seq_lens_cpu": [1], - "num_computed_tokens_cpu": [0], - "slot_mapping": [902, 1000], - "max_seq_len": 1, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [802, 7], - "positions": [2, 1], - }, - }, - }, - }, - { - "name": "layer3_shift_bounded_start_pos0", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [10, 11, 12, 13], - "target_positions": [0, 1, 2, 3], - "last_token_indices": [1], - "common_attn_metadata": { - "query_start_loc": [0, 4], - "query_start_loc_cpu": [0, 4], - "seq_lens": [4], - "seq_lens_cpu": [4], - "num_computed_tokens_cpu": [0], - "slot_mapping": [100, 101, 102, 103], - "max_seq_len": 4, - }, - "expected": { - "prev_token_ids": [10, 11, 12, 13], - "prev_positions": [0, 1, 2, 3], - "last_token_indices": [1], - "seq_lens": [4], - "seq_lens_cpu": [4], - "num_computed_tokens_cpu": [0], - "slot_mapping": [100, 101, 102, 103], - "max_seq_len": 4, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [10, 11], - "positions": [0, 1], - }, - }, - }, - }, - { - "name": "layer3_shift_bounded_start_pos", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [10, 11, 12, 13, 14], - "target_positions": [0, 1, 2, 3, 4], - "last_token_indices": [1], - "common_attn_metadata": { - "query_start_loc": [0, 5], - "query_start_loc_cpu": [0, 5], - "seq_lens": [5], - "seq_lens_cpu": [5], - "num_computed_tokens_cpu": [1], - "slot_mapping": [100, 101, 102, 103, 104], - "max_seq_len": 5, - }, - 
"expected": { - "prev_token_ids": [10, 11, 12, 13, 14], - "prev_positions": [0, 1, 2, 3, 4], - "last_token_indices": [1], - "seq_lens": [5], - "seq_lens_cpu": [5], - "num_computed_tokens_cpu": [1], - "slot_mapping": [100, 101, 102, 103, 104], - "max_seq_len": 5, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [10, 11], - "positions": [0, 1], - }, - }, - }, - }, - { - "name": "layer3_shift2_bounded_remaining", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [10, 11, 12, 13, 14], - "target_positions": [0, 1, 2, 3, 4], - "last_token_indices": [2], - "common_attn_metadata": { - "query_start_loc": [0, 5], - "query_start_loc_cpu": [0, 5], - "seq_lens": [5], - "seq_lens_cpu": [5], - "num_computed_tokens_cpu": [2], - "slot_mapping": [100, 101, 102, 103, 104], - "max_seq_len": 5, - }, - "expected": { - "prev_token_ids": [10, 11, 12, 13, 14], - "prev_positions": [0, 1, 2, 3, 4], - "last_token_indices": [2], - "seq_lens": [5], - "seq_lens_cpu": [5], - "num_computed_tokens_cpu": [2], - "slot_mapping": [100, 101, 102, 103, 104], - "max_seq_len": 5, - "cached_by_req": { - "req-0": { - "len": 3, - "token_ids": [10, 11, 12], - "positions": [0, 1, 2], - }, - }, - }, - }, - { - "name": "layer3_shift3_full_cache_window", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [20, 21, 22, 23, 24], - "target_positions": [0, 1, 2, 3, 4], - "last_token_indices": [1], - "common_attn_metadata": { - "query_start_loc": [0, 5], - "query_start_loc_cpu": [0, 5], - "seq_lens": [5], - "seq_lens_cpu": [5], - "num_computed_tokens_cpu": [3], - "slot_mapping": [100, 101, 102, 103, 104], - "max_seq_len": 5, - }, - "expected": { - "prev_token_ids": [20, 21, 22, 23, 24], - "prev_positions": [0, 1, 2, 3, 4], - "last_token_indices": [1], - "seq_lens": [5], - "seq_lens_cpu": [5], - "num_computed_tokens_cpu": [3], - "slot_mapping": [100, 101, 102, 103, 104], - "max_seq_len": 5, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [20, 21], - 
"positions": [0, 1], - }, - }, - }, - }, - { - "name": "layer3_batch2_shift1_and1", - "batch_size": 2, - "running_req_ids": ["req-0", "req-1"], - "target_token_ids": [10, 11, 12, 13, 20, 21, 22], - "target_positions": [0, 1, 2, 3, 0, 1, 2], - "last_token_indices": [1, 5], - "common_attn_metadata": { - "query_start_loc": [0, 4, 7], - "query_start_loc_cpu": [0, 4, 7], - "seq_lens": [4, 3], - "seq_lens_cpu": [4, 3], - "num_computed_tokens_cpu": [1, 1], - "slot_mapping": [100, 101, 102, 103, 200, 201, 202], - "max_seq_len": 4, - }, - "expected": { - "prev_token_ids": [10, 11, 12, 13, 20, 21, 22], - "prev_positions": [0, 1, 2, 3, 0, 1, 2], - "last_token_indices": [1, 5], - "seq_lens": [4, 3], - "seq_lens_cpu": [4, 3], - "num_computed_tokens_cpu": [1, 1], - "slot_mapping": [100, 101, 102, 103, 200, 201, 202], - "max_seq_len": 4, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [10, 11], - "positions": [0, 1], - }, - "req-1": { - "len": 2, - "token_ids": [20, 21], - "positions": [0, 1], - }, - }, - }, - }, - { - "name": "layer3_batch4_mixed_shifts", - "batch_size": 4, - "running_req_ids": ["req-0", "req-1", "req-2", "req-3"], - "target_token_ids": [10, 11, 20, 21, 22, 30, 31, 32, 33, 40, 41, 42], - "target_positions": [0, 1, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2], - "last_token_indices": [1, 2, 6, 10], - "common_attn_metadata": { - "query_start_loc": [0, 2, 5, 9, 12], - "query_start_loc_cpu": [0, 2, 5, 9, 12], - "seq_lens": [2, 3, 4, 3], - "seq_lens_cpu": [2, 3, 4, 3], - "num_computed_tokens_cpu": [0, 1, 2, 1], - "slot_mapping": [ - 100, - 101, - 102, - 103, - 104, - 105, - 106, - 107, - 108, - 109, - 110, - 111, - ], - "max_seq_len": 4, - }, - "expected": { - "prev_token_ids": [10, 11, 911, 20, 21, 30, 31, 32, 33, 40, 41, 42], - "prev_positions": [0, 1, 1, 1, 2, 0, 1, 2, 3, 0, 1, 2], - "last_token_indices": [1, 3, 6, 10], - "seq_lens": [2, 2, 4, 3], - "seq_lens_cpu": [2, 2, 4, 3], - "num_computed_tokens_cpu": [0, 0, 2, 1], - "slot_mapping": [ - 100, - 101, - 991, - 
102, - 103, - 105, - 106, - 107, - 108, - 109, - 110, - 111, - ], - "max_seq_len": 4, - "cached_by_req": { - "req-0": { - "len": 2, - "token_ids": [10, 11], - "positions": [0, 1], - }, - "req-1": { - "len": 2, - "token_ids": [911, 20], - "positions": [1, 1], - }, - "req-2": { - "len": 2, - "token_ids": [30, 31], - "positions": [0, 1], - }, - "req-3": { - "len": 2, - "token_ids": [40, 41], - "positions": [0, 1], - }, - }, - }, - }, - { - "name": "layer3_batch2_shift0_and2", - "batch_size": 2, - "running_req_ids": ["req-0", "req-1"], - "target_token_ids": [30, 31, 32, 40, 41, 42, 43], - "target_positions": [0, 1, 2, 0, 1, 2, 3], - "last_token_indices": [2, 4], - "common_attn_metadata": { - "query_start_loc": [0, 3, 7], - "query_start_loc_cpu": [0, 3, 7], - "seq_lens": [3, 4], - "seq_lens_cpu": [3, 4], - "num_computed_tokens_cpu": [0, 2], - "slot_mapping": [100, 101, 102, 200, 201, 202, 203], - "max_seq_len": 4, - }, - "expected": { - "prev_token_ids": [30, 31, 32, 40, 41, 42, 43], - "prev_positions": [0, 1, 2, 0, 1, 2, 3], - "last_token_indices": [2, 4], - "seq_lens": [3, 4], - "seq_lens_cpu": [3, 4], - "num_computed_tokens_cpu": [0, 2], - "slot_mapping": [100, 101, 102, 200, 201, 202, 203], - "max_seq_len": 4, - "cached_by_req": { - "req-0": { - "len": 3, - "token_ids": [30, 31, 32], - "positions": [0, 1, 2], - }, - "req-1": { - "len": 2, - "token_ids": [40, 41], - "positions": [0, 1], - }, - }, - }, - }, -] - -LAYER5_CASES = [ - { - "name": "layer5_cache_window5", - "batch_size": 1, - "running_req_ids": ["req-0"], - "target_token_ids": [1, 2, 3, 4, 5, 6], - "target_positions": [0, 1, 2, 3, 4, 5], - "last_token_indices": [2], - "common_attn_metadata": { - "query_start_loc": [0, 6], - "query_start_loc_cpu": [0, 6], - "seq_lens": [6], - "seq_lens_cpu": [6], - "num_computed_tokens_cpu": [2], - "slot_mapping": [100, 101, 102, 103, 104, 105], - "max_seq_len": 6, - }, - "expected": { - "prev_token_ids": [1, 2, 3, 4, 5, 6], - "prev_positions": [0, 1, 2, 3, 4, 5], - 
"last_token_indices": [2], - "seq_lens": [6], - "seq_lens_cpu": [6], - "num_computed_tokens_cpu": [2], - "slot_mapping": [100, 101, 102, 103, 104, 105], - "max_seq_len": 6, - "cached_by_req": { - "req-0": { - "len": 3, - "token_ids": [1, 2, 3], - "positions": [0, 1, 2], - }, - }, - }, - }, -] - - -def _run_adjust_input_case(proposer_stub, case, layer_num): - proposer_stub.layer_num = layer_num - proposer_stub.running_req_ids = case["running_req_ids"] - meta = case["common_attn_metadata"] - common_attn_metadata = SimpleNamespace( - query_start_loc=torch.tensor(meta["query_start_loc"], dtype=torch.int32), - query_start_loc_cpu=torch.tensor(meta["query_start_loc"], dtype=torch.int32), - seq_lens=torch.tensor(meta["seq_lens"], dtype=torch.int32), - seq_lens_cpu=torch.tensor(meta["seq_lens_cpu"], dtype=torch.int32), - num_computed_tokens_cpu=torch.tensor( - meta["num_computed_tokens_cpu"], dtype=torch.int32 - ), - slot_mapping=torch.tensor(meta["slot_mapping"], dtype=torch.int32), - max_seq_len=meta["max_seq_len"], - ) - - target_token_ids = torch.tensor(case["target_token_ids"], dtype=torch.int32) - target_positions = torch.tensor(case["target_positions"], dtype=torch.int64) - target_hidden_states = torch.arange( - 0, target_token_ids.numel() * 3, dtype=torch.float32 - ).reshape(-1, 3) - last_token_indices = torch.tensor(case["last_token_indices"], dtype=torch.int32) - - prev_token_ids, prev_positions, _, _, _ = proposer_stub.adjust_input( - batch_size=case["batch_size"], - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - last_token_indices=last_token_indices, - common_attn_metadata=common_attn_metadata, - ) - - expected = case["expected"] - assert prev_token_ids.tolist() == expected["prev_token_ids"] - assert prev_positions.tolist() == expected["prev_positions"] - assert last_token_indices.tolist() == expected["last_token_indices"] - assert common_attn_metadata.seq_lens.tolist() == 
expected["seq_lens"] - assert common_attn_metadata.seq_lens_cpu.tolist() == expected["seq_lens_cpu"] - assert ( - common_attn_metadata.num_computed_tokens_cpu.tolist() - == expected["num_computed_tokens_cpu"] - ) - assert common_attn_metadata.slot_mapping.tolist() == expected["slot_mapping"] - assert common_attn_metadata.max_seq_len == expected["max_seq_len"] - - for req_id, cached_expect in expected["cached_by_req"].items(): - cached = proposer_stub.draft_input_states_pool[req_id] - assert cached.len == cached_expect["len"] - assert cached.token_ids.tolist() == cached_expect["token_ids"] - assert cached.positions.tolist() == cached_expect["positions"] - - -@pytest.mark.parametrize( - "case", LAYER3_CASES, ids=[case["name"] for case in LAYER3_CASES] -) -def test_adjust_input_layer3_cases(proposer_stub, case): - _run_adjust_input_case(proposer_stub, case, layer_num=3) - - -@pytest.mark.parametrize( - "case", LAYER5_CASES, ids=[case["name"] for case in LAYER5_CASES] -) -def test_adjust_input_layer5_cases(proposer_stub, case): - _run_adjust_input_case(proposer_stub, case, layer_num=5) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 5db0ddd778f0..c7648487866c 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -41,7 +41,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", - "step3p5_mtp", + "step3p5_mtp" ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] SpeculativeMethod = Literal[ @@ -76,12 +76,6 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" - - enable_multi_layers_mtp: bool = False - """If set to True, the MTP method will run multiple layers of MTP - speculator. If set to False, it will run only one layer of MTP speculator. 
- This is only effective when the method is set to `mtp`.""" - draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" @@ -270,12 +264,14 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) - + if hf_config.model_type == "step3p5": hf_config.model_type = "step3p5_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) - hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]}) - + hf_config.update( + {"n_predict": n_predict, "architectures": ["Step3p5MTP"]} + ) + if initial_architecture == "MistralLarge3ForCausalLM": hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) @@ -408,10 +404,7 @@ def __post_init__(self): MTPModelTypes ): self.method = "mtp" - if ( - self.enable_multi_layers_mtp is False - and self.num_speculative_tokens > 1 - ): + if self.num_speculative_tokens > 1: logger.warning( "Enabling num_speculative_tokens > 1 will run" "multiple times of forward on same MTP layer" diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index a07d5ad18929..9f5c8d1433cd 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -8,7 +8,6 @@ from torch import nn import vllm.envs as envs -from vllm.model_executor.layers.attention import Attention from vllm.v1.attention.backend import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py index 996a717e54ba..a7747d09e9d5 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -1,19 +1,18 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -24,10 +23,11 @@ class SharedHead(nn.Module): + def __init__( self, config: PretrainedConfig, - quant_config: QuantizationConfig | None = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) @@ -40,18 +40,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Step3p5AMultiTokenPredictorLayer(nn.Module): + def __init__( self, - *, - vllm_config: VllmConfig, - prefix: str = "", + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + parallel_config: ParallelConfig = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) @@ -70,22 
+69,25 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: torch.Tensor | None = None, + inputs_embeds: Optional[torch.Tensor] = None, spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1) - ) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) - hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states) + hidden_states = self.mtp_block(positions=positions, + hidden_states=hidden_states) return hidden_states class Step3p5AMultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -96,18 +98,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict( - { - str(idx): Step3p5AMultiTokenPredictorLayer( - vllm_config=vllm_config, - prefix=f"{prefix}.layers.{idx}", - ) - for idx in range( - self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers, - ) - } - ) + self.layers = torch.nn.ModuleDict({ + str(idx): + Step3p5AMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + parallel_config=vllm_config.parallel_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) self.logits_processor = LogitsProcessor(config.vocab_size) @@ -116,12 +119,12 @@ def forward( input_ids: 
torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: torch.Tensor | None = None, + inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = spec_step_idx % self.num_mtp_layers + current_step_idx = (spec_step_idx % self.num_mtp_layers) return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -135,11 +138,11 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = spec_step_idx % self.num_mtp_layers - mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] - logits = self.logits_processor( - mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) - ) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states)) return logits def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -147,13 +150,14 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: class Step3p5MTP(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config self.vllm_config = vllm_config - self.model = Step3p5AMultiTokenPredictor( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) + self.model = Step3p5AMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -163,23 +167,23 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = 
None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model( - input_ids, positions, hidden_states, inputs_embeds, spec_step_idx - ) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> torch.Tensor | None: + ) -> Optional[torch.Tensor]: return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: vllm_config = self.vllm_config config = vllm_config.model_config.hf_config @@ -207,7 +211,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if "embed_tokens" not in name and spec_layer is None: continue name = self._rewrite_spec_layer_name(spec_layer, name) - for param_name, weight_name, shard_id in stacked_params_mapping: + for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -217,7 +221,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if ("mlp.experts." in name) and name not in params_dict: + if (("mlp.experts." in name) and name not in params_dict): continue if "experts" in name or "moe" in name: continue @@ -237,52 +241,40 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
- if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue param = params_dict[name] weight_loader = param.weight_loader for expert_id in range(loaded_weight.shape[0]): loaded_weight_expert = loaded_weight[expert_id] - weight_loader( - param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id, - ) + weight_loader(param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id) loaded_params.add(name) break else: # Skip loading extra bias for GPTQ models. - if ( - name.endswith(".bias") - and name not in params_dict - or "tok_embeddings" in name - ): + if name.endswith( + ".bias" + ) and name not in params_dict or "tok_embeddings" in name: continue - mtp_start_layer_idx = config.num_hidden_layers - num_mtp_layers = config.num_nextn_predict_layers - - for idx in range( - mtp_start_layer_idx, mtp_start_layer_idx + num_mtp_layers - ): - if f"{idx}.transformer." in name: - name = name.replace(".transformer.", ".") + if f"{config.num_hidden_layers}.transformer." 
in name: + name = name.replace(".transformer.", ".") if "shared_head" in name: - name = name.replace("shared_head.output", "shared_head.head") + name = name.replace("shared_head.output", + "shared_head.head") if "embed_tokens" in name: - assert ( - hasattr(self.config, "num_nextn_predict_layers") - and self.config.num_nextn_predict_layers > 0 - ) + assert hasattr( + self.config, "num_nextn_predict_layers" + ) and self.config.num_nextn_predict_layers > 0 name = "model.embed_tokens.weight" param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) params_need_to_load = set(params_dict.keys()) @@ -300,8 +292,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: missing_params = list(params_need_to_load - loaded_params) param_name_example = missing_params[0] raise RuntimeError( - f"Some parameters like {param_name_example} are not in the checkpoint" - f" and will falsely use random initialization" + f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization" ) return loaded_params @@ -311,11 +302,7 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: Add .mtp_block for modules in transformer layer block for spec layer """ spec_layer_weight_names = [ - "embed_tokens", - "enorm", - "hnorm", - "eh_proj", - "shared_head", + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" ] spec_layer_weight = False for weight_name in spec_layer_weight_names: @@ -324,7 +311,6 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace( - f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." 
- ) + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") return name diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 605c91012d3f..fd12dfe045a4 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -947,7 +947,6 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo def _get_kv_cache_groups_uniform_page_size( - vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: """ @@ -1062,18 +1061,8 @@ def _get_kv_cache_groups_uniform_page_size( # the same and will cause memory waste. # To avoid this, we assign layers[i::num_groups] to the i-th group # instead of layers[i * group_size: (i + 1) * group_size] - - # for support multi layer mtp, we need to - # make all mtp layers in the same group - if ( - vllm_config.speculative_config is not None - and vllm_config.speculative_config.enable_multi_layers_mtp - ): - for i in range(0, len(layers), group_size): - grouped_layers.append(layers[i : i + group_size]) - else: - for i in range(num_groups): - grouped_layers.append(layers[i::num_groups]) + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) @@ -1258,9 +1247,7 @@ def get_kv_cache_groups( # have the same physical memory per block per layer. Split the layers # into groups with the same number of layers, and thus same total page # size. 
- return _get_kv_cache_groups_uniform_page_size( - vllm_config=vllm_config, kv_cache_spec=kv_cache_spec - ) + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) def generate_scheduler_kv_cache_config( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0e50e0eb9f72..43b84f4be8a2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -381,11 +381,8 @@ def propose( input_ids = None inputs_embeds = self.inputs_embeds[:num_input_tokens] else: - self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( - self.input_ids[:num_tokens], - ) input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = self.inputs_embeds[:num_input_tokens] + inputs_embeds = None model_kwargs = { "input_ids": input_ids, @@ -578,12 +575,12 @@ def propose( self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) + input_ids = None inputs_embeds = self.inputs_embeds[:input_batch_size] else: - self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids) input_ids = self.input_ids[:input_batch_size] - inputs_embeds = self.inputs_embeds[:input_batch_size] + inputs_embeds = None # Run the model. 
model_kwargs = { @@ -1329,7 +1326,7 @@ def dummy_run( inputs_embeds = self.inputs_embeds[:num_input_tokens] else: input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = self.inputs_embeds[:num_input_tokens] + inputs_embeds = None kwargs = dict( input_ids=input_ids, diff --git a/vllm/v1/spec_decode/multi_layer_eagle.py b/vllm/v1/spec_decode/multi_layer_eagle.py deleted file mode 100644 index 9731d8bcbe9c..000000000000 --- a/vllm/v1/spec_decode/multi_layer_eagle.py +++ /dev/null @@ -1,475 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any - -import torch - -from vllm.config import CUDAGraphMode, VllmConfig -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.v1.attention.backend import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, -) -from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.eagle import EagleProposer - -logger = init_logger(__name__) - -PADDING_SLOT_ID = -1 - - -class DraftInputStates: - def __init__( - self, - len: int, - token_ids: torch.Tensor, - hidden_states: torch.Tensor, - positions: torch.Tensor, - slot_mapping: torch.Tensor, - ): - self.len = len - self.token_ids = token_ids - self.hidden_states = hidden_states - self.positions = positions - self.slot_mapping = slot_mapping - - -class MultiLayerEagleProposer(EagleProposer): - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - runner=None, - ): - super().__init__(vllm_config, device, runner) - - self.layer_num: int = getattr( - self.speculative_config.draft_model_config.hf_text_config, "n_predict", 0 - ) - self.num_speculative_tokens: int = ( - self.speculative_config.num_speculative_tokens - ) - if self.num_speculative_tokens != self.layer_num: - logger.warning_once( - "For multi_layer_eagle, num_speculative_tokens " - 
"does not match layer_num, adjusting to layer_num" - ) - self.num_speculative_tokens = self.layer_num - self.running_req_ids: list[str] | None = None - self.draft_input_states_pool: dict[str, DraftInputStates] = {} - - def set_running_req_ids(self, req_ids: list[str]): - self.running_req_ids = req_ids - - def _get_draft_input_states(self, req_id: str, len: int) -> DraftInputStates: - draft_input_states = self.draft_input_states_pool.get(req_id, None) - assert draft_input_states is not None - assert draft_input_states.len >= len - return draft_input_states - - def clean_req_cache(self, req_id: str): - self.draft_input_states_pool.pop(req_id, None) - - def adjust_input( - self, - batch_size: int, - target_token_ids: torch.Tensor, - target_positions: torch.Tensor, - target_hidden_states: torch.Tensor, - last_token_indices: torch.Tensor, - common_attn_metadata: CommonAttentionMetadata, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any, dict[str, Any]]: - start_token_indices = common_attn_metadata.query_start_loc[:-1] - start_token_pos = target_positions[start_token_indices] - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - - query_start_loc_cpu_np = query_start_loc_cpu.numpy() - start_token_indices_cpu = query_start_loc_cpu_np[:-1] - end_token_indices_cpu = query_start_loc_cpu_np[1:] - 1 - last_token_indices_cpu = last_token_indices.cpu().numpy() - start_token_pos_cpu = start_token_pos.cpu().numpy() - - prev_token_ids = target_token_ids - prev_positions = target_positions - prev_hidden_states = target_hidden_states - - for i in range(batch_size): - last_token_index: int = int(last_token_indices_cpu[i]) - start_token_index: int = int(start_token_indices_cpu[i]) - end_token_index: int = int(end_token_indices_cpu[i]) - start_pos: int = int(start_token_pos_cpu[i]) - assert self.running_req_ids is not None - req_id = self.running_req_ids[i] - shift = min(end_token_index - last_token_index, start_pos) - - modify_last_token_index = last_token_index - 
if shift > 0: - - def shift_input( - input: torch.Tensor, - cached: torch.Tensor, - start_token_index: int = start_token_index, - end_token_index: int = end_token_index, - shift: int = shift, - ) -> torch.Tensor: - window_len = end_token_index - start_token_index + 1 - dest = input.narrow( - 0, start_token_index + shift, window_len - shift - ) - # clone is used to ensure correctness in the case of - # overlap between src and dest - src = input.narrow(0, start_token_index, window_len - shift).clone() - dest.copy_(src) - head = input.narrow(0, start_token_index, shift) - head.copy_(cached[-shift:]) - return input - - cached_input_state = self._get_draft_input_states(req_id, shift) - prev_token_ids = shift_input( - prev_token_ids, cached_input_state.token_ids - ) - prev_positions = shift_input( - prev_positions, cached_input_state.positions - ) - prev_hidden_states = shift_input( - prev_hidden_states, cached_input_state.hidden_states - ) - common_attn_metadata.slot_mapping = shift_input( - common_attn_metadata.slot_mapping, cached_input_state.slot_mapping - ) - common_attn_metadata.seq_lens[i] -= shift - common_attn_metadata.num_computed_tokens_cpu[i] -= shift - common_attn_metadata.seq_lens_cpu[i] -= shift - - modify_last_token_index = last_token_index + shift - last_token_indices[i] += shift - - cache_start_index = max( - start_token_index, modify_last_token_index + 1 - self.layer_num - ) - - self.draft_input_states_pool[req_id] = DraftInputStates( - len=modify_last_token_index + 1 - cache_start_index, - token_ids=prev_token_ids[ - cache_start_index : modify_last_token_index + 1 - ].clone(), - hidden_states=prev_hidden_states[ - cache_start_index : modify_last_token_index + 1 - ].clone(), - positions=prev_positions[ - cache_start_index : modify_last_token_index + 1 - ].clone(), - slot_mapping=common_attn_metadata.slot_mapping[ - cache_start_index : modify_last_token_index + 1 - ].clone(), - ) - - common_attn_metadata.max_seq_len = torch.max( - 
common_attn_metadata.seq_lens - ).item() - - if self.attn_metadata_builder is None: - attn_metadata_builder = self._get_attention_metadata_builder() - else: - attn_metadata_builder = self.attn_metadata_builder - - assert isinstance(attn_metadata_builder, AttentionMetadataBuilder) - - attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0 - ) - - # FIXME: support hybrid kv for draft model (remove separate indexer) - if self.draft_indexer_metadata_builder: - draft_indexer_metadata = ( - self.draft_indexer_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, - draft_index=0, - ) - ) - else: - draft_indexer_metadata = None - - # At this moment, we assume all eagle layers belong to the same KV - # cache group, thus using the same attention metadata. - per_layer_attn_metadata = {} - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata - for layer_name in self.indexer_layer_names: - assert draft_indexer_metadata is not None - per_layer_attn_metadata[layer_name] = draft_indexer_metadata - - return ( - prev_token_ids, - prev_positions, - prev_hidden_states, - attn_metadata, - per_layer_attn_metadata, - ) - - def initial_inputs_for_forward( - self, - num_tokens: int, - prev_token_ids: torch.Tensor, - prev_positions: torch.Tensor, - prev_hidden_states: torch.Tensor, - next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor, - spec_step_idx: int = 0, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, - ): - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = prev_token_ids[1:] - # Replace the last token with the next token. 
- # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids - self._set_positions(num_tokens, prev_positions) - self.hidden_states[:num_tokens] = prev_hidden_states[:num_tokens] - if self.supports_mm_inputs: - mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) - - self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( - self.input_ids[:num_tokens], - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) - else: - self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( - self.input_ids[:num_tokens], - ) - - def draft_model_forward( - self, - num_tokens: int, - per_layer_attn_metadata: dict[str, Any], - last_token_indices: torch.Tensor, - sampling_metadata: SamplingMetadata, - common_attn_metadata: CommonAttentionMetadata, - spec_step_idx: int = 0, - ): - num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( - num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens - ) - - cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( - num_tokens_dp_padded - ) - num_input_tokens = batch_desc.num_tokens - - if num_tokens_across_dp is not None: - num_tokens_across_dp[self.dp_rank] = num_input_tokens - - if self.supports_mm_inputs: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] - else: - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = self.inputs_embeds[:num_input_tokens] - - model_kwargs = { - "input_ids": input_ids, - "positions": self._get_positions(num_input_tokens), - "hidden_states": self.hidden_states[:num_input_tokens], - "inputs_embeds": inputs_embeds, - "spec_step_idx": spec_step_idx, - } - - with set_forward_context( - per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - num_input_tokens, common_attn_metadata.slot_mapping - ), - ): - 
last_hidden_states = self.model(**model_kwargs) - - sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits( - sample_hidden_states, spec_step_idx=spec_step_idx - ) - - draft_token_ids = logits.argmax(dim=-1) - - return draft_token_ids, last_hidden_states - - def propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] or [3, num_tokens] when M-RoPE is enabled - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, - common_attn_metadata: CommonAttentionMetadata, - sampling_metadata: SamplingMetadata, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, - num_rejected_tokens_gpu: torch.Tensor | None = None, - slot_mappings: dict[str, torch.Tensor] - | list[dict[str, torch.Tensor]] - | None = None, - ) -> torch.Tensor: - assert self.method == "mtp" - assert self.runner is not None - assert target_positions.dim() == 1, ( - "MultiLayerEagleProposer does not support M-RoPE yet; " - f"got target_positions with shape {tuple(target_positions.shape)}" - ) - - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - - if last_token_indices is None: - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 - - ( - prev_token_ids, - prev_positions, - prev_hidden_states, - attn_metadata, - per_layer_attn_metadata, - ) = self.adjust_input( - batch_size=batch_size, - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - last_token_indices=last_token_indices, - common_attn_metadata=common_attn_metadata, - ) - - if isinstance(attn_metadata, TreeAttentionMetadata): - raise NotImplementedError( - "Tree attention is not supported for multi layer eagle." 
- ) - - if self.allowed_attn_types is not None and not isinstance( - attn_metadata, self.allowed_attn_types - ): - raise ValueError( - f"Unsupported attention metadata type for speculative " - "decoding for multi layer eagle: " - f"{type(attn_metadata)}. Supported types are: " - f"{self.allowed_attn_types}" - ) - - # Generate the remaining draft tokens. - draft_token_ids_list: list[torch.Tensor] = [] - - for token_index in range(self.num_speculative_tokens): - if token_index != 0: - prev_token_ids = self.input_ids[:num_tokens].clone() - next_token_ids = draft_token_ids_list[-1].int() - - self.initial_inputs_for_forward( - num_tokens=num_tokens, - prev_token_ids=prev_token_ids, - prev_positions=prev_positions, - prev_hidden_states=prev_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=last_token_indices, - spec_step_idx=token_index, - mm_embed_inputs=mm_embed_inputs, - ) - - draft_token_ids, prev_hidden_states = self.draft_model_forward( - num_tokens=num_tokens, - per_layer_attn_metadata=per_layer_attn_metadata, - last_token_indices=last_token_indices, - sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, - spec_step_idx=token_index, - ) - - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - return draft_token_ids.view(-1, 1) - - draft_token_ids_list.append(draft_token_ids) - - # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - - return draft_token_ids - - def prepare_inputs( - self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], - num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It updates to the common_attn_metadata to account for the rejected - tokens (and newly sampled tokens). 
It also returns the token indices - of the tokens that should be fed to the speculator. - """ - raise Exception( - "speculative_config.disable_padded_drafter_batch" - " is not supported now for MultiLayerEagleProposer." - ) - - @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - use_cudagraphs: bool = True, - is_graph_capturing: bool = False, - slot_mappings: dict[str, torch.Tensor] | None = None, - ) -> None: - num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( - num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens - ) - if use_cudagraphs: - cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( - num_tokens_dp_padded - ) - num_input_tokens = batch_desc.num_tokens - else: - cudagraph_runtime_mode = CUDAGraphMode.NONE - num_input_tokens = num_tokens_dp_padded - if num_tokens_across_dp is not None: - num_tokens_across_dp[self.dp_rank] = num_input_tokens - - # Make sure to use EAGLE's own buffer during cudagraph capture. - if ( - self.attn_layer_names - and slot_mappings is not None - and self.attn_layer_names[0] in slot_mappings - ): - slot_mapping_dict = self._get_slot_mapping(num_input_tokens) - else: - slot_mapping_dict = slot_mappings or {} - - for fwd_idx in range(self.layer_num): - with set_forward_context( - None, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=slot_mapping_dict, - ): - if self.supports_mm_inputs: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] - else: - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = self.inputs_embeds[:num_input_tokens] - - model_kwargs = { - "input_ids": input_ids, - "positions": self._get_positions(num_input_tokens), - "hidden_states": self.hidden_states[:num_input_tokens], - "inputs_embeds": inputs_embeds, - "spec_step_idx": fwd_idx, - } - - self.model(**model_kwargs) diff --git 
a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f01604073dfa..061ac8680157 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -155,7 +155,6 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -464,15 +463,7 @@ def __init__( elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - if ( - self.speculative_config.enable_multi_layers_mtp - and self.speculative_config.method == "mtp" - ): - self.drafter = MultiLayerEagleProposer( - self.vllm_config, self.device, self - ) - else: - self.drafter = EagleProposer(self.vllm_config, self.device, self) + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = ( self.drafter.eagle3_use_aux_hidden_state @@ -893,10 +884,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) - if hasattr(self, "drafter") and isinstance( - self.drafter, MultiLayerEagleProposer - ): - self.drafter.clean_req_cache(req_id) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. 
This happens when a request is aborted and @@ -4091,9 +4078,6 @@ def propose_draft_token_ids( else: mm_embed_inputs = None - if isinstance(self.drafter, MultiLayerEagleProposer): - self.drafter.set_running_req_ids(self.input_batch.req_ids) - draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, From ede83c9459db51498cb52ae2391d1b37a080e36e Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 13:42:19 +0800 Subject: [PATCH 11/34] revert routed_scaling_factor passthrough in fused moe --- .../fused_moe/router/custom_routing_router.py | 4 ---- .../layers/fused_moe/router/router_factory.py | 3 +-- vllm/model_executor/models/step3p5.py | 21 ++++++------------- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index a1f931156750..bbd73f57924d 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -21,7 +21,6 @@ def __init__( renormalize: bool = True, enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, - routed_scaling_factor: float = 1.0, ): super().__init__( top_k=top_k, @@ -32,7 +31,6 @@ def __init__( ) self.custom_routing_function = custom_routing_function self.renormalize = renormalize - self.routed_scaling_factor = routed_scaling_factor @property def routing_method_type(self) -> RoutingMethodType: @@ -56,8 +54,6 @@ def _compute_routing( topk=self.top_k, renormalize=self.renormalize, ) - if self.routed_scaling_factor != 1.0: - topk_weights *= self.routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to( torch.int32 if indices_type is None else indices_type ) diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py 
b/vllm/model_executor/layers/fused_moe/router/router_factory.py index 330741dd5f4d..890f846d3539 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -40,7 +40,7 @@ def create_fused_moe_router( topk_group: int | None = None, scoring_func: str = "softmax", num_fused_shared_experts: int = 0, - # grouped topk + fused topk bias parameters/ custom router function + # grouped topk + fused topk bias parameters routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, # custom routing paramaters @@ -130,7 +130,6 @@ def create_fused_moe_router( renormalize=renormalize, enable_eplb=enable_eplb, indices_type_getter=indices_type_getter, - routed_scaling_factor=routed_scaling_factor, ) if e_score_correction_bias is not None: diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 9f5c8d1433cd..e19d5714f968 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -45,19 +45,6 @@ logger = init_logger(__name__) -def sigmoid_routing_function(hidden_states: torch.Tensor, - gating_output: torch.Tensor, topk: int, - renormalize: bool): - gating_output = gating_output.float() - gate_prob = torch.sigmoid(gating_output) - gate_prob = gate_prob / gate_prob.sum(dim=-1, keepdim=True) - topk_prob, indices = torch.topk(gate_prob, k=topk, dim=1) - expert_topk_weight = topk_prob - if renormalize: - expert_topk_weight = expert_topk_weight / torch.sum( - expert_topk_weight, dim=-1, keepdim=True) - return expert_topk_weight, indices.to(torch.int32) - class Step3p5MLP(nn.Module): def __init__( @@ -309,13 +296,15 @@ def __init__(self, assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1" self.use_moe_router_bias = config.use_moe_router_bias + self.routed_scaling_factor = getattr(config, "moe_router_scaling_factor", + 1.0) + if self.routed_scaling_factor is None: + 
self.routed_scaling_factor = 1.0 if self.use_moe_router_bias: self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts, dtype=torch.float32), requires_grad=False) custom_routing_function = self.router_bias_func - elif config.moe_router_activation == "sigmoid": - custom_routing_function = sigmoid_routing_function else: custom_routing_function = None self.need_fp32_gate = config.need_fp32_gate @@ -363,6 +352,8 @@ def router_bias_func(self, hidden_states: torch.Tensor, if renormalize: expert_topk_weight = expert_topk_weight / ( torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20) + if self.routed_scaling_factor != 1.0: + expert_topk_weight *= self.routed_scaling_factor return expert_topk_weight, indices.to(torch.int32) def forward( From 03c65f8d6a6bbfa5053c74082df69ac6d546f15c Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 15:49:59 +0800 Subject: [PATCH 12/34] refactor: revert activation_limit for swiglustep --- vllm/envs.py | 4 ---- .../layers/fused_moe/batched_deep_gemm_moe.py | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 11 ++++----- .../layers/fused_moe/fallback.py | 2 -- .../fused_moe/flashinfer_cutlass_moe.py | 1 - .../layers/fused_moe/fused_moe.py | 14 +++-------- .../fused_moe/fused_moe_modular_method.py | 1 - vllm/model_executor/layers/fused_moe/layer.py | 6 ----- .../layers/fused_moe/modular_kernel.py | 18 ++------------ .../fused_moe/router/custom_routing_router.py | 1 + .../fused_moe/unquantized_fused_moe_method.py | 1 - vllm/model_executor/layers/fused_moe/utils.py | 12 +++------- .../layers/quantization/bitsandbytes.py | 1 - .../compressed_tensors_moe.py | 4 ---- .../layers/quantization/experts_int8.py | 1 - .../model_executor/layers/quantization/fp8.py | 1 - .../layers/quantization/modelopt.py | 2 -- .../layers/quantization/quark/quark_moe.py | 2 -- vllm/model_executor/models/step3p5.py | 24 ++++++++++++------- 19 files changed, 30 insertions(+), 78 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 
7d2bdadc47e3..741a2163c91f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -254,7 +254,6 @@ VLLM_DEBUG_MFU_METRICS: bool = False VLLM_DISABLE_LOG_LOGO: bool = False VLLM_LORA_DISABLE_PDL: bool = False - VLLM_USE_FUSED_ALL_REDUCE: bool = True def get_default_cache_root(): @@ -1632,9 +1631,6 @@ def _get_or_set_default() -> str: # Disable PDL for LoRA, as enabling PDL with LoRA on SM100 causes # Triton compilation to fail. "VLLM_LORA_DISABLE_PDL": lambda: bool(int(os.getenv("VLLM_LORA_DISABLE_PDL", "0"))), - # If set, step3p5 will use symmcomm inplace all reduce. - "VLLM_USE_FUSED_ALL_REDUCE": - lambda: os.getenv("VLLM_USE_FUSED_ALL_REDUCE", "true").lower() in ("1", "true"), } diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 8c081ef440f1..ac37cff9329a 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -304,7 +304,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "swiglustep"] + return activation in ["silu"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 775ba132992e..222ff124a05c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -144,7 +144,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "swiglustep"] + return activation in ["silu", "swiglustep_clip_7"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -185,7 +185,7 @@ def workspace_shapes( return (workspace1, workspace2, output) def _act_mul_quant( - self, input: 
torch.Tensor, output: torch.Tensor, activation: str, activation_limit: float | None = None, + self, input: torch.Tensor, output: torch.Tensor, activation: str ) -> tuple[torch.Tensor, torch.Tensor]: assert self.block_shape is not None block_k = self.block_shape[1] @@ -199,7 +199,7 @@ def _act_mul_quant( act_out = torch.empty( (M_sum, activation_out_dim), dtype=input.dtype, device=input.device ) - self.activation(activation, act_out, input, activation_limit) + self.activation(activation, act_out, input) a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm( act_out, block_k, @@ -220,7 +220,7 @@ def _act_mul_quant( act_out = torch.empty( (M_sum, activation_out_dim), dtype=input.dtype, device=input.device ) - self.activation(activation, act_out, input, activation_limit) + self.activation(activation, act_out, input) return per_token_group_quant_fp8( act_out, block_k, column_major_scales=True, out_q=output ) @@ -242,7 +242,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - activation_limit: float | None = None, ): assert a1q_scale is not None assert a2_scale is None @@ -291,7 +290,7 @@ def apply( workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim) ) a2q, a2q_scale = self._act_mul_quant( - input=mm1_out.view(-1, N), output=quant_out, activation=activation, activation_limit=activation_limit + input=mm1_out.view(-1, N), output=quant_out, activation=activation ) mm2_out = _resize_cache(workspace2, (M_sum, K)) diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 014bd30e8cbe..07e5b80059f0 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -168,7 +168,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - activation_limit: float | None = None, ): experts = 
self._select_experts_impl(hidden_states, w1, w2) experts.apply( @@ -187,5 +186,4 @@ def apply( workspace2, expert_tokens_meta, apply_router_weight_on_input, - activation_limit ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index cff0822207c6..7c27da46fee5 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -210,7 +210,6 @@ def apply( workspace2: torch.Tensor | None, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool | None, - activation_limit: float | None = None ): from flashinfer.fused_moe.core import ActivationType diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1872ab817302..94e03acefa9b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1343,7 +1343,6 @@ def inplace_fused_experts( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, - activation_limit: float | None = None, ) -> None: fused_experts_impl( hidden_states, @@ -1371,7 +1370,6 @@ def inplace_fused_experts( block_shape, w1_bias, w2_bias, - activation_limit, ) @@ -1400,7 +1398,6 @@ def inplace_fused_experts_fake( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, - activation_limit: float | None = None, ) -> None: pass @@ -1438,7 +1435,6 @@ def outplace_fused_experts( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, - activation_limit: float | None = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -1466,7 +1462,6 @@ def outplace_fused_experts( block_shape, w1_bias, w2_bias, - activation_limit, ) @@ -1494,7 +1489,6 @@ def 
outplace_fused_experts_fake( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, - activation_limit: float | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1619,7 +1613,6 @@ def fused_experts_impl( block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, - activation_limit: float | None = None, ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: @@ -1848,7 +1841,7 @@ def fused_experts_impl( ) apply_moe_activation( - activation, intermediate_cache2, intermediate_cache1.view(-1, N), activation_limit=activation_limit, + activation, intermediate_cache2, intermediate_cache1.view(-1, N) ) qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( @@ -1946,7 +1939,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai", "swiglustep"] + return activation in ["silu", "gelu", "swigluoai", "swiglustep_clip_7"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: @@ -1995,7 +1988,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - activation_limit: float | None = None, ): # Check constraints. 
if self.quant_config.use_int4_w4a16: @@ -2083,7 +2075,7 @@ def apply( ) self.activation( - activation, intermediate_cache2, intermediate_cache1.view(-1, N), activation_limit=activation_limit + activation, intermediate_cache2, intermediate_cache1.view(-1, N) ) a2q_scale: torch.Tensor | None = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index eaab3bcff677..7a2244a9bc1d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -101,7 +101,6 @@ def apply( topk_ids=topk_ids, inplace=self.allow_inplace, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=None if self.disable_expert_map else layer.expert_map, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 532a700986d6..5fe4bce7a4fc 100755 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -332,7 +332,6 @@ def __init__( expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, router_logits_dtype: torch.dtype | None = None, - activation_limit: float | None = None, ): super().__init__() @@ -520,11 +519,6 @@ def __init__( self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation - self.activation_limit = activation_limit - if self.activation == "swiglustep" and self.activation_limit is None: - raise ValueError( - "activation='swiglustep' requires activation_limit to be set." 
- ) self.router = create_fused_moe_router( top_k=top_k, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index beb78ec7f937..940a2c55f73a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -414,7 +414,6 @@ def __init__( self.quant_config = quant_config self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - self.activation_limit: float | None = None @property def expects_unquantized_inputs(self) -> bool: @@ -712,14 +711,9 @@ def adjust_N_for_activation(N: int, activation: str) -> int: return N if is_no_mul else N // 2 def activation( - self, activation: str, output: torch.Tensor, input: torch.Tensor, activation_limit: float | None = None + self, activation: str, output: torch.Tensor, input: torch.Tensor ) -> None: - apply_moe_activation( - activation, - output, - input, - activation_limit=activation_limit, - ) + apply_moe_activation(activation, output, input) def enable_chunking(self): return ( @@ -747,7 +741,6 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - activation_limit: float | None = None ) -> None: """ This function computes the intermediate result of a Mixture of Experts @@ -1146,7 +1139,6 @@ def _fused_experts( expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, expert_tokens_meta: ExpertTokensMetadata | None, - activation_limit: float | None = None ) -> torch.Tensor: _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( a1q, w1, w2, topk_ids @@ -1222,7 +1214,6 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]: workspace2=workspace2, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, - activation_limit=activation_limit, ) return fused_out @@ -1306,7 +1297,6 @@ def forward( global_num_experts: int = -1, expert_map: 
torch.Tensor | None = None, apply_router_weight_on_input: bool = False, - activation_limit: float | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -1336,9 +1326,6 @@ def forward( - torch.Tensor: The output tensor after applying the MoE layer. """ - # Propagate any activation parameters to the experts implementation. - self.fused_experts.activation_limit = activation_limit - if inplace and self.shared_experts is None and not disable_inplace(): output = hidden_states else: @@ -1371,7 +1358,6 @@ def forward( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, expert_tokens_meta=expert_tokens_meta, - activation_limit=activation_limit ) return self._finalize( diff --git a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py index bbd73f57924d..0367189ca1ab 100644 --- a/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py +++ b/vllm/model_executor/layers/fused_moe/router/custom_routing_router.py @@ -54,6 +54,7 @@ def _compute_routing( topk=self.top_k, renormalize=self.renormalize, ) + return topk_weights.to(torch.float32), topk_ids.to( torch.int32 if indices_type is None else indices_type ) diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 4bcd367c000d..2ddaf272b147 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -344,7 +344,6 @@ def forward_cuda( topk_ids=topk_ids, inplace=self.use_inplace, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git 
a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index f0d98d807a28..903eef1d7e86 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -331,12 +331,11 @@ def apply_moe_activation( activation: str, output: torch.Tensor, input: torch.Tensor, - activation_limit: float | None = None, ) -> torch.Tensor: """ Apply MoE activation function. - For *_and_mul activations (silu, gelu, swigluoai, swiglustep): + For *_and_mul activations (silu, gelu, swigluoai): - Expects output.size(-1) * 2 == input.size(-1) For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul): @@ -359,14 +358,9 @@ def apply_moe_activation( torch.ops._C.gelu_and_mul(output, input) elif activation == "swigluoai": torch.ops._C.swigluoai_and_mul(output, input) - elif activation == "swiglustep": - if activation_limit is None: - raise ValueError( - "activation='swiglustep' requires activation_limit to be set." - ) + elif activation == "swiglustep_clip_7": from vllm.model_executor.layers.activation import swiglustep_and_mul_out - - swiglustep_and_mul_out(output, input, activation_limit) + swiglustep_and_mul_out(output, input, 7.0) # Activations without gated multiplication elif activation == SILU_NO_MUL: output.copy_(F.silu(input)) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 22211a17e445..8b6b1e445f35 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -517,7 +517,6 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 2e3944f0c585..dbfa8fb9bd7a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -672,7 +672,6 @@ def apply( topk_ids, inplace=False, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, @@ -1085,7 +1084,6 @@ def apply( topk_ids, inplace=self.use_inplace, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, # TODO(rob): investigate the disable_expert_map introduced by: # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 @@ -1225,7 +1223,6 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, @@ -1983,7 +1980,6 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index e1d594a76730..5a0bb5d30f9e 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -151,7 +151,6 @@ def apply( topk_ids=topk_ids, inplace=True, 
activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 208cd1914a82..6436a9ae0abf 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1019,7 +1019,6 @@ def apply( topk_ids, inplace=self.use_inplace, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 17abc92b3970..e76c109eceda 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -969,7 +969,6 @@ def apply( topk_ids, inplace=self.use_inplace, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, @@ -1541,7 +1540,6 @@ def apply( topk_ids, inplace=False, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d519390d9089..d2f0213e8091 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -400,7 +400,6 @@ def apply( topk_ids=topk_ids, 
inplace=True, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, @@ -772,7 +771,6 @@ def apply( topk_ids=topk_ids, inplace=True, activation=layer.activation, - activation_limit=getattr(layer, "activation_limit", None), global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index e19d5714f968..b440c0dce1b8 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -310,14 +310,22 @@ def __init__(self, self.need_fp32_gate = config.need_fp32_gate layer_idx = int(prefix.split("layers.")[1].split(".")[0]) activation = "silu" - swigluoai_step_limit = None - if config.swiglu_limits and config.swiglu_limits[ - layer_idx] is not None and config.swiglu_limits[layer_idx] != 0: - swigluoai_step_limit = config.swiglu_limits[layer_idx] - activation = "swiglustep" + swiglu_limits = config.swiglu_limits or [] + swiglu_limit = ( + swiglu_limits[layer_idx] if layer_idx < len(swiglu_limits) else None + ) + if swiglu_limit not in (None, 0): + swiglu_limit = float(swiglu_limit) + swiglu_activation_by_limit = { + 7.0: "swiglustep_clip_7", + 16.0: "swiglustep_clip_16", + } + activation = swiglu_activation_by_limit.get(swiglu_limit, activation) logger.info( - f"step3p5 layer_idx: {layer_idx}, activation limit: {config.swiglu_limits[layer_idx]}, will use swiglustep" + f"step3p5 layer_idx: {layer_idx}, activation: {activation}, " + f"limit: {swiglu_limit}" ) + self.experts = SharedFusedMoE( shared_experts=shared_experts, num_experts=config.moe_num_experts, @@ -328,7 +336,6 @@ def __init__(self, renormalize=config.norm_expert_weight, quant_config=quant_config, activation=activation, - 
activation_limit=swigluoai_step_limit if swigluoai_step_limit else None, prefix=f"{prefix}.experts", custom_routing_function=custom_routing_function, routed_scaling_factor=config.moe_router_scaling_factor, @@ -437,7 +444,7 @@ def __init__(self, self.use_moe = False self.tp_group = get_tp_group() self.use_fused_all_reduce = get_tensor_model_parallel_world_size( - ) > 1 and get_dp_group().world_size == 1 and envs.VLLM_USE_FUSED_ALL_REDUCE + ) > 1 and get_dp_group().world_size == 1 if self.use_fused_all_reduce: logger.warning_once("Enable custom fused all reduce...") else: @@ -503,7 +510,6 @@ def forward(self, positions: torch.Tensor, if self.use_moe: shared_output, moe_output = self.moe(hidden_states) - # share expert & moe 可以合并all reduce ffn_output = self.add_and_maybe_inplace_all_reduce( moe_output, shared_output) else: From 952026ec48bf4ce63d289dd42bb30bd0872ab070 Mon Sep 17 00:00:00 2001 From: zetaohong Date: Thu, 29 Jan 2026 16:28:40 +0800 Subject: [PATCH 13/34] fix step3p5 reasoning parser --- vllm/reasoning/step3p5_reasoning_parser.py | 32 +++++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py index 93aa7f5ee08d..f558f59d7407 100644 --- a/vllm/reasoning/step3p5_reasoning_parser.py +++ b/vllm/reasoning/step3p5_reasoning_parser.py @@ -14,6 +14,7 @@ from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser + class Step3p5ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for Step3p5 model. @@ -39,6 +40,25 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): # whether it is immediately before . self._pending_reasoning_newline = False + # Used to delay the reasoning end detection. + # This is necessary to remove the newline appears immediately after , + # which may cause the end detection to be delayed by one round. 
+ self.end_offset = 1 + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + if self.end_token_id in input_ids and self.end_offset > 0: + self.end_offset -= 1 + return False + return self.end_offset < 1 + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + if self.end_token_id in input_ids and self.end_offset > 0: + self.end_offset -= 1 + return False + return self.end_offset < 1 + def extract_reasoning( self, model_output: str, @@ -68,11 +88,6 @@ def extract_reasoning_streaming( remaining = delta_text.removeprefix("\n") return DeltaMessage(content=remaining) if remaining else None - # If we are about to see the end token, any pending newline is - # immediately before and should be dropped. - if self.end_token_id in delta_token_ids and self._pending_reasoning_newline: - self._pending_reasoning_newline = False - ret = super().extract_reasoning_streaming( previous_text, current_text, @@ -98,9 +113,9 @@ def extract_reasoning_streaming( content = delta_text[end_index + len(self.end_token) :] ret = DeltaMessage(reasoning=reasoning, content=content or None) elif self.end_token_id in previous_token_ids: - ret = DeltaMessage(content=delta_text or None) + ret = DeltaMessage(content=delta_text) else: - ret = DeltaMessage(reasoning=delta_text or None) + ret = DeltaMessage(reasoning=delta_text) reasoning_to_output = ret.reasoning content_to_output = ret.content @@ -122,6 +137,9 @@ def extract_reasoning_streaming( # Content: handle the newline immediately after . if content_to_output is not None: + # No need to get into parser again to remove newline after . + self.end_offset -= 1 + # If we have content, reasoning must have ended. 
self._pending_reasoning_newline = False From 285a3abe25a1d053e1f94f891ed0abd72d5de1d0 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 16:40:45 +0800 Subject: [PATCH 14/34] fix: fix comments about swiglustep, default limit=7.0 --- vllm/model_executor/layers/activation.py | 6 ++---- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/utils.py | 4 ++-- vllm/model_executor/models/step3p5.py | 7 ++----- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 2a623ede3d39..b242b892f097 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -25,7 +25,7 @@ def swiglustep_and_mul_out( out: torch.Tensor, x: torch.Tensor, - limit: float, + limit: float = 7.0, ) -> torch.Tensor: """Out-variant of swiglustep activation. @@ -337,9 +337,7 @@ class SwigluStepAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ - # --8<-- [end:swiglustep_and_mul] - - def __init__(self, limit: float): + def __init__(self, limit: float = 7.0): super().__init__() if limit is None: raise ValueError("SwigluStepAndMul requires limit to be set.") diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 222ff124a05c..00d55bfb7e04 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -144,7 +144,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "swiglustep_clip_7"] + return activation in ["silu", "swiglustep"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py 
b/vllm/model_executor/layers/fused_moe/fused_moe.py index 94e03acefa9b..120b3c2d1e2c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1939,7 +1939,7 @@ def _supports_quant_scheme( @staticmethod def _supports_activation(activation: str) -> bool: - return activation in ["silu", "gelu", "swigluoai", "swiglustep_clip_7"] + return activation in ["silu", "gelu", "swigluoai", "swiglustep"] @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 903eef1d7e86..29964f566e0e 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -358,9 +358,9 @@ def apply_moe_activation( torch.ops._C.gelu_and_mul(output, input) elif activation == "swigluoai": torch.ops._C.swigluoai_and_mul(output, input) - elif activation == "swiglustep_clip_7": + elif activation == "swiglustep": from vllm.model_executor.layers.activation import swiglustep_and_mul_out - swiglustep_and_mul_out(output, input, 7.0) + swiglustep_and_mul_out(output, input) # Activations without gated multiplication elif activation == SILU_NO_MUL: output.copy_(F.silu(input)) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index b440c0dce1b8..e494bfcec23c 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -316,11 +316,8 @@ def __init__(self, ) if swiglu_limit not in (None, 0): swiglu_limit = float(swiglu_limit) - swiglu_activation_by_limit = { - 7.0: "swiglustep_clip_7", - 16.0: "swiglustep_clip_16", - } - activation = swiglu_activation_by_limit.get(swiglu_limit, activation) + assert swiglu_limit == 7.0, "swiglu_limit in fused moe block only suport 7.0 now." 
+ activation = "swiglustep" logger.info( f"step3p5 layer_idx: {layer_idx}, activation: {activation}, " f"limit: {swiglu_limit}" From 864602c054604c8430de3c1c7b8576e4dcabc91b Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 17:54:44 +0800 Subject: [PATCH 15/34] format: remove useless field in step3p5 config --- vllm/model_executor/models/step3p5.py | 34 ++++++++-------------- vllm/transformers_utils/configs/step3p5.py | 21 ------------- 2 files changed, 12 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index e494bfcec23c..2d8ed2a65e16 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -181,6 +181,7 @@ def __init__( rope_parameters: dict[str, Any] = { "rope_type": "default", "partial_rotary_factor": partial_rotary_factor, + "rope_theta": self.rope_theta } if rope_scaling is not None: if not isinstance(rope_scaling, dict): @@ -188,8 +189,6 @@ def __init__( "rope_scaling must be a dict for Step3p5Attention." ) rope_parameters.update(rope_scaling) - rope_parameters["rope_theta"] = self.rope_theta - rope_parameters["partial_rotary_factor"] = partial_rotary_factor self.rotary_emb = get_rope( head_size=self.head_dim, @@ -296,18 +295,13 @@ def __init__(self, assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1" self.use_moe_router_bias = config.use_moe_router_bias - self.routed_scaling_factor = getattr(config, "moe_router_scaling_factor", - 1.0) - if self.routed_scaling_factor is None: - self.routed_scaling_factor = 1.0 - if self.use_moe_router_bias: - self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts, - dtype=torch.float32), - requires_grad=False) - custom_routing_function = self.router_bias_func - else: - custom_routing_function = None + assert self.use_moe_router_bias == True, "Only support use_moe_router_bias == true." 
+ self.routed_scaling_factor = config.moe_router_scaling_factor + self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts, + dtype=torch.float32), + requires_grad=False) self.need_fp32_gate = config.need_fp32_gate + assert self.need_fp32_gate, "Router logits must use FP32 precision for numerical stability." layer_idx = int(prefix.split("layers.")[1].split(".")[0]) activation = "silu" swiglu_limits = config.swiglu_limits or [] @@ -316,7 +310,7 @@ def __init__(self, ) if swiglu_limit not in (None, 0): swiglu_limit = float(swiglu_limit) - assert swiglu_limit == 7.0, "swiglu_limit in fused moe block only suport 7.0 now." + assert swiglu_limit == 7.0, "Swiglu limit in fused moe block only suport 7.0 now." activation = "swiglustep" logger.info( f"step3p5 layer_idx: {layer_idx}, activation: {activation}, " @@ -334,7 +328,7 @@ def __init__(self, quant_config=quant_config, activation=activation, prefix=f"{prefix}.experts", - custom_routing_function=custom_routing_function, + custom_routing_function=self.router_bias_func, routed_scaling_factor=config.moe_router_scaling_factor, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, @@ -356,8 +350,7 @@ def router_bias_func(self, hidden_states: torch.Tensor, if renormalize: expert_topk_weight = expert_topk_weight / ( torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20) - if self.routed_scaling_factor != 1.0: - expert_topk_weight *= self.routed_scaling_factor + expert_topk_weight *= self.routed_scaling_factor return expert_topk_weight, indices.to(torch.int32) def forward( @@ -366,11 +359,8 @@ def forward( orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - if self.need_fp32_gate: - router_logits = hidden_states.to( - torch.float32) @ self.gate.weight.to(torch.float32).t() - else: - router_logits, _ = self.gate(hidden_states) + router_logits = hidden_states.to( + torch.float32) @ self.gate.weight.to(torch.float32).t() 
shared_out, final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py index 5d34608a917c..33daf931c05f 100644 --- a/vllm/transformers_utils/configs/step3p5.py +++ b/vllm/transformers_utils/configs/step3p5.py @@ -23,25 +23,17 @@ def __init__( moe_intermediate_size: int = 10240, moe_num_experts: int = 16, moe_top_k: int = 4, - max_pos_interp_ratio: float = 1, moe_layer_offset: int = 0, moe_dynamic_exp_p: float = 1.0, rope_theta: Optional[Union[float, list[float]]] = 500000, rope_scaling: Optional[dict[str, Any]] = None, head_dim: Optional[int] = None, share_expert_dim: Optional[int] = None, - allgather_dtype: Optional[str] = None, - share_q_dim: Optional[int] = None, norm_expert_weight: bool = True, bos_token_id: Optional[Union[list[int], int]] = None, eos_token_id: Optional[Union[list[int], int]] = None, moe_router_activation: str = "softmax", moe_router_scaling_factor: float = 1.0, - qk_nope_head_dim: Optional[int] = None, - qk_rope_head_dim: Optional[int] = None, - v_head_dim: Optional[int] = None, - q_lora_rank: Optional[int] = None, - kv_lora_rank: Optional[int] = None, att_impl_type: str = "MFA", use_head_wise_attn_gate: bool = False, use_moe_router_bias: bool = False, @@ -51,10 +43,8 @@ def __init__( yarn_only_types: Optional[list[str]] = None, attention_other_setting: Optional[dict[str, Any]] = None, num_nextn_predict_layers: int = 0, - swa_num_attention_heads: Optional[int] = None, swiglu_limits: Optional[list[float]] = None, swiglu_limits_shared: Optional[list[float]] = None, - zero_centered: bool = True, max_position_embeddings: Optional[int] = None, **kwargs, ): @@ -72,7 +62,6 @@ def __init__( self.moe_num_experts = moe_num_experts self.num_experts_per_tok = moe_top_k self.moe_top_k = moe_top_k - self.max_pos_interp_ratio = max_pos_interp_ratio self.moe_layer_offset = moe_layer_offset self.moe_dynamic_exp_p = 
moe_dynamic_exp_p @@ -83,21 +72,13 @@ def __init__( self.share_expert_dim = self.moe_intermediate_size * self.moe_top_k else: self.share_expert_dim = share_expert_dim - self.share_q_dim = share_q_dim self.norm_expert_weight = norm_expert_weight - self.allgather_dtype = allgather_dtype - self.max_position_embeddings = max_position_embeddings self.moe_router_activation = moe_router_activation self.moe_router_scaling_factor = moe_router_scaling_factor self.use_moe_router_bias = use_moe_router_bias self.need_fp32_gate = need_fp32_gate - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.v_head_dim = v_head_dim - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank self.att_impl_type = att_impl_type self.use_head_wise_attn_gate = use_head_wise_attn_gate @@ -106,10 +87,8 @@ def __init__( self.yarn_only_types = yarn_only_types self.attention_other_setting = attention_other_setting self.num_nextn_predict_layers = num_nextn_predict_layers - self.swa_num_attention_heads = swa_num_attention_heads self.swiglu_limits = swiglu_limits self.swiglu_limits_shared = swiglu_limits_shared - self.zero_centered = zero_centered resolved_bos_token_id = 1 if bos_token_id is None else bos_token_id resolved_eos_token_id = [2, 3] if eos_token_id is None else eos_token_id From 5389b7d6cbd86f095b9addeed86aa9dc1060846a Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 18:29:34 +0800 Subject: [PATCH 16/34] format: simplified rope scaling params update --- vllm/model_executor/models/step3p5.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 2d8ed2a65e16..d5df8fcdd30f 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -105,7 +105,7 @@ def __init__( rope_theta: Optional[Union[float, list[float]]] = 10000, cache_config: Optional[CacheConfig] = None, quant_config: 
Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, + rope_scaling: Optional[dict[str, Any]] = None, prefix: str = "", attn_type: str = AttentionType.DECODER, # Step3p5 specific args @@ -178,17 +178,15 @@ def __init__( prefix=f"{prefix}.o_proj", ) - rope_parameters: dict[str, Any] = { - "rope_type": "default", - "partial_rotary_factor": partial_rotary_factor, - "rope_theta": self.rope_theta - } - if rope_scaling is not None: - if not isinstance(rope_scaling, dict): - raise ValueError( - "rope_scaling must be a dict for Step3p5Attention." - ) - rope_parameters.update(rope_scaling) + if rope_scaling is not None and not isinstance(rope_scaling, dict): + raise ValueError("rope_scaling must be a dict for Step3p5Attention.") + + rope_parameters: dict[str, Any] = ( + dict(rope_scaling) if rope_scaling is not None else {} + ) + rope_parameters.setdefault("rope_type", "default") + rope_parameters["rope_theta"] = self.rope_theta + rope_parameters["partial_rotary_factor"] = partial_rotary_factor self.rotary_emb = get_rope( head_size=self.head_dim, @@ -378,7 +376,6 @@ def __init__(self, super().__init__() config = config.hf_config self.hidden_size = config.hidden_size - rope_scaling = getattr(config, "rope_scaling", None) layer_idx = int(prefix.split("layers.")[1].split(".")[0]) self.layer_idx = layer_idx if cache_config is not None: @@ -412,7 +409,7 @@ def __init__(self, config, 'head_dim', None), cache_config=cache_config, quant_config=quant_config, - rope_scaling=rope_scaling, + rope_scaling=getattr(config, "rope_scaling", None), sliding_window=getattr(config, 'sliding_window', None), use_head_wise_attn_gate=getattr(config, "use_head_wise_attn_gate", From 459448f10024249678517f8ad389b012bef9f976 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Thu, 29 Jan 2026 20:36:27 +0800 Subject: [PATCH 17/34] fix: mtp --- vllm/model_executor/models/step3p5_mtp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/step3p5_mtp.py 
b/vllm/model_executor/models/step3p5_mtp.py index a7747d09e9d5..967e6b54202c 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -73,8 +73,6 @@ def forward( spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None - # masking inputs at position 0, as not needed by MTP - inputs_embeds[positions == 0] = 0 inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) From a5ab3708511dc8a239bc4f59cd2fed8b539a516b Mon Sep 17 00:00:00 2001 From: csy0225 Date: Fri, 30 Jan 2026 10:41:12 +0800 Subject: [PATCH 18/34] format: fix config.json default value --- vllm/transformers_utils/configs/step3p5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py index 33daf931c05f..814486ed8b36 100644 --- a/vllm/transformers_utils/configs/step3p5.py +++ b/vllm/transformers_utils/configs/step3p5.py @@ -34,10 +34,10 @@ def __init__( eos_token_id: Optional[Union[list[int], int]] = None, moe_router_activation: str = "softmax", moe_router_scaling_factor: float = 1.0, - att_impl_type: str = "MFA", + att_impl_type: str = "GQA", use_head_wise_attn_gate: bool = False, - use_moe_router_bias: bool = False, - need_fp32_gate: bool = False, + use_moe_router_bias: bool = True, + need_fp32_gate: bool = True, layer_types: Optional[list[str]] = None, use_rope_layers: Optional[list[bool]] = None, yarn_only_types: Optional[list[str]] = None, From 4086e80f712b391c4705d4bb17df880b6c9cf657 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Fri, 30 Jan 2026 10:51:18 +0800 Subject: [PATCH 19/34] format: pre-commit tool fix --- vllm/config/speculative.py | 10 +- vllm/model_executor/layers/fused_moe/utils.py | 1 + vllm/model_executor/models/step3p5.py | 470 +++++++----- vllm/model_executor/models/step3p5_mtp.py | 172 +++-- vllm/reasoning/step3p5_reasoning_parser.py | 5 +- 
vllm/tool_parsers/step3p5_tool_parser.py | 688 ++++++++++-------- vllm/transformers_utils/configs/step3p5.py | 28 +- 7 files changed, 782 insertions(+), 592 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c7648487866c..966d168b47ab 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -41,7 +41,7 @@ "longcat_flash_mtp", "mtp", "pangu_ultra_moe_mtp", - "step3p5_mtp" + "step3p5_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] SpeculativeMethod = Literal[ @@ -264,14 +264,12 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} ) - + if hf_config.model_type == "step3p5": hf_config.model_type = "step3p5_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) - hf_config.update( - {"n_predict": n_predict, "architectures": ["Step3p5MTP"]} - ) - + hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]}) + if initial_architecture == "MistralLarge3ForCausalLM": hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 29964f566e0e..50c216c43b81 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -360,6 +360,7 @@ def apply_moe_activation( torch.ops._C.swigluoai_and_mul(output, input) elif activation == "swiglustep": from vllm.model_executor.layers.activation import swiglustep_and_mul_out + swiglustep_and_mul_out(output, input) # Activations without gated multiplication elif activation == SILU_NO_MUL: diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index d5df8fcdd30f..161d17190113 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -1,87 +1,103 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" + from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn -import vllm.envs as envs -from vllm.v1.attention.backend import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig -from vllm.distributed import (get_dp_group, - get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_group) +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.fused_moe.shared_fused_moe import ( - SharedFusedMoE) from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, 
ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backend import AttentionType from .interfaces import MixtureOfExperts, SupportsPP -from .utils import (PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) -class Step3p5MLP(nn.Module): +class Step3p5MLP(nn.Module): def __init__( self, config: ModelConfig, hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") + prefix=f"{prefix}.gate_up_proj", + ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() self.prefix = prefix self.hidden_size = hidden_size self.limit = None layer_idx = int(prefix.split("layers.")[1].split(".")[0]) - if config.swiglu_limits_shared and config.swiglu_limits_shared[ - layer_idx] is not None and config.swiglu_limits_shared[ - layer_idx] != 0: + if ( + config.swiglu_limits_shared + and config.swiglu_limits_shared[layer_idx] is not None + and config.swiglu_limits_shared[layer_idx] != 0 + ): self.limit = config.swiglu_limits_shared[layer_idx] self.act_fn = SwigluStepAndMul(limit=self.limit) @@ -91,30 +107,30 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: output, _ = self.down_proj(intermediate_act) return output -class Step3p5Attention(nn.Module): +class Step3p5Attention(nn.Module): def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, max_position: int = 4096 * 32, - head_dim: Optional[int] = None, + head_dim: int | None = None, rms_norm_eps: float = 1e-06, qkv_bias: bool = False, - rope_theta: Optional[Union[float, list[float]]] = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[dict[str, Any]] = None, + rope_theta: float | list[float] | None = 10000, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: dict[str, Any] | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, # Step3p5 specific args - sliding_window: Optional[int] = None, + sliding_window: int | None = None, use_head_wise_attn_gate: bool = False, layer_types: list = None, use_rope_layers: list = None, yarn_only_types: list = None, - swa_num_attention_heads: Optional[int] = None, + swa_num_attention_heads: int | None = None, partial_rotary_factor: float = 1.0, ): super().__init__() @@ -123,16 +139,14 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() self.layer_idx = extract_layer_index(prefix) if layer_types: - enable_sliding_window = 
layer_types[ - self.layer_idx] == "sliding_attention" + enable_sliding_window = layer_types[self.layer_idx] == "sliding_attention" else: enable_sliding_window = self.layer_idx % 2 == 0 - if yarn_only_types and layer_types[ - self.layer_idx] not in yarn_only_types: + if yarn_only_types and layer_types[self.layer_idx] not in yarn_only_types: rope_scaling = None if sliding_window is not None and enable_sliding_window: - sliding_window = (sliding_window) + sliding_window = sliding_window if swa_num_attention_heads is not None: num_heads = swa_num_attention_heads self.total_num_heads = swa_num_attention_heads @@ -223,15 +237,15 @@ def __init__( self.max_position_embeddings = max_position assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5 - self.rotary_dim = self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2 + self.rotary_dim = ( + self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2 + ) def qk_norm_rope(self, q, k, positions): - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head.contiguous()) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head.contiguous()) k = k_by_head.view(k.shape) if self.use_rope: @@ -244,28 +258,30 @@ def forward( hidden_states: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.qk_norm_rope(q, k, positions) attn_output = self.attn(q, k, v) if self.use_head_wise_attn_gate: extra_dims, _ = self.g_proj(hidden_states) - output = attn_output.view( - *attn_output.shape[:-1], self.num_heads, - 
self.head_dim) * extra_dims.unsqueeze(-1).sigmoid() + output = ( + attn_output.view(*attn_output.shape[:-1], self.num_heads, self.head_dim) + * extra_dims.unsqueeze(-1).sigmoid() + ) attn_output = output.view(*attn_output.shape) output, _ = self.o_proj(attn_output) return output -class FusedMoEBlock(nn.Module): - def __init__(self, - config: ModelConfig, - parallel_config: ParallelConfig, - shared_experts: torch.nn.Module, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = ""): +class FusedMoEBlock(nn.Module): + def __init__( + self, + config: ModelConfig, + parallel_config: ParallelConfig, + shared_experts: torch.nn.Module, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + prefix: str = "", + ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.layer_idx = extract_layer_index(prefix) @@ -280,26 +296,30 @@ def __init__(self, self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) if self.tp_size > config.moe_num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.moe_num_experts}.") + f"the number of experts {config.moe_num_experts}." + ) assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1" self.use_moe_router_bias = config.use_moe_router_bias - assert self.use_moe_router_bias == True, "Only support use_moe_router_bias == true." + assert self.use_moe_router_bias, "Only support use_moe_router_bias is true." 
self.routed_scaling_factor = config.moe_router_scaling_factor - self.router_bias = nn.Parameter(torch.zeros(config.moe_num_experts, - dtype=torch.float32), - requires_grad=False) + self.router_bias = nn.Parameter( + torch.zeros(config.moe_num_experts, dtype=torch.float32), + requires_grad=False, + ) self.need_fp32_gate = config.need_fp32_gate - assert self.need_fp32_gate, "Router logits must use FP32 precision for numerical stability." + assert self.need_fp32_gate, ( + "Router logits must use FP32 precision for numerical stability." + ) layer_idx = int(prefix.split("layers.")[1].split(".")[0]) activation = "silu" swiglu_limits = config.swiglu_limits or [] @@ -308,13 +328,17 @@ def __init__(self, ) if swiglu_limit not in (None, 0): swiglu_limit = float(swiglu_limit) - assert swiglu_limit == 7.0, "Swiglu limit in fused moe block only suport 7.0 now." + assert swiglu_limit == 7.0, ( + "Swiglu limit in fused moe block only suport 7.0 now." + ) activation = "swiglustep" logger.info( - f"step3p5 layer_idx: {layer_idx}, activation: {activation}, " - f"limit: {swiglu_limit}" + "step3p5 layer_idx: %s, activation: %s, limit: %s", + layer_idx, + activation, + swiglu_limit, ) - + self.experts = SharedFusedMoE( shared_experts=shared_experts, num_experts=config.moe_num_experts, @@ -331,15 +355,21 @@ def __init__(self, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, ) - self.gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - - def router_bias_func(self, hidden_states: torch.Tensor, - gating_output: torch.Tensor, topk: int, - renormalize: bool): + self.gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + def router_bias_func( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ): gate_prob = torch.sigmoid(gating_output.float()) 
gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0) _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1) @@ -347,32 +377,34 @@ def router_bias_func(self, hidden_states: torch.Tensor, expert_topk_weight = topk_prob if renormalize: expert_topk_weight = expert_topk_weight / ( - torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20) + torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20 + ) expert_topk_weight *= self.routed_scaling_factor return expert_topk_weight, indices.to(torch.int32) - def forward( - self, - hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = hidden_states.to( - torch.float32) @ self.gate.weight.to(torch.float32).t() + router_logits = ( + hidden_states.to(torch.float32) @ self.gate.weight.to(torch.float32).t() + ) shared_out, final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits) + hidden_states=hidden_states, router_logits=router_logits + ) return shared_out, final_hidden_states.view(orig_shape) class Step3p5DecoderLayer(nn.Module): - - def __init__(self, - config: ModelConfig, - parallel_config: ParallelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: ModelConfig, + parallel_config: ParallelConfig, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: super().__init__() config = config.hf_config self.hidden_size = config.hidden_size @@ -384,41 +416,46 @@ def __init__(self, num_attention_heads = None num_attention_groups = None head_dim = None - if getattr(config, "attention_other_setting", None) and getattr( - config, "layer_types", []) and 
config.layer_types[ - layer_idx] == config.attention_other_setting[ - 'attention_type']: + if ( + getattr(config, "attention_other_setting", None) + and getattr(config, "layer_types", []) + and config.layer_types[layer_idx] + == config.attention_other_setting["attention_type"] + ): num_attention_heads = config.attention_other_setting[ - 'num_attention_heads'] + "num_attention_heads" + ] num_attention_groups = config.attention_other_setting[ - 'num_attention_groups'] - head_dim = config.attention_other_setting['head_dim'] - partial_rotary_factors = getattr(config, "partial_rotary_factors", - []) + "num_attention_groups" + ] + head_dim = config.attention_other_setting["head_dim"] + partial_rotary_factors = getattr(config, "partial_rotary_factors", []) self.self_attn = Step3p5Attention( hidden_size=self.hidden_size, num_heads=num_attention_heads - if num_attention_heads else config.num_attention_heads, + if num_attention_heads + else config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=num_attention_groups - if num_attention_groups else config.num_attention_groups, + if num_attention_groups + else config.num_attention_groups, rope_theta=config.rope_theta, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=head_dim if head_dim else getattr( - config, 'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=head_dim if head_dim else getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, rope_scaling=getattr(config, "rope_scaling", None), - sliding_window=getattr(config, 'sliding_window', None), - use_head_wise_attn_gate=getattr(config, - "use_head_wise_attn_gate", - False), + sliding_window=getattr(config, "sliding_window", None), + use_head_wise_attn_gate=getattr( + config, "use_head_wise_attn_gate", False + ), layer_types=getattr(config, "layer_types", []), use_rope_layers=getattr(config, "use_rope_layers", []), 
yarn_only_types=getattr(config, "yarn_only_types", []), partial_rotary_factor=partial_rotary_factors[layer_idx] - if partial_rotary_factors else 1.0, + if partial_rotary_factors + else 1.0, prefix=f"{prefix}.self_attn", ) else: @@ -427,8 +464,10 @@ def __init__(self, ) self.use_moe = False self.tp_group = get_tp_group() - self.use_fused_all_reduce = get_tensor_model_parallel_world_size( - ) > 1 and get_dp_group().world_size == 1 + self.use_fused_all_reduce = ( + get_tensor_model_parallel_world_size() > 1 + and get_dp_group().world_size == 1 + ) if self.use_fused_all_reduce: logger.warning_once("Enable custom fused all reduce...") else: @@ -436,15 +475,16 @@ def __init__(self, moe_layers_enum = getattr(config, "moe_layers_enum", None) if moe_layers_enum is not None: - moe_layers_idx = [ - int(i) for i in moe_layers_enum.strip().split(',') - ] + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] else: moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] if layer_idx in moe_layers_idx: reduce_results = True - if self.use_fused_all_reduce or self.tp_group.world_size == 1 and get_ep_group( - ).world_size == 1: + if ( + self.use_fused_all_reduce + or self.tp_group.world_size == 1 + and get_ep_group().world_size == 1 + ): reduce_results = False self.share_expert = Step3p5MLP( config=config, @@ -453,34 +493,43 @@ def __init__(self, hidden_act="silu", reduce_results=reduce_results, quant_config=quant_config, - prefix=f"{prefix}.share_expert") - self.moe = FusedMoEBlock(shared_experts=self.share_expert, - config=config, - parallel_config=parallel_config, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.moe") + prefix=f"{prefix}.share_expert", + ) + self.moe = FusedMoEBlock( + shared_experts=self.share_expert, + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.moe", + ) self.use_moe = True else: - self.mlp = 
Step3p5MLP(config=config, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act="silu", - quant_config=quant_config, - reduce_results=True, - prefix=f"{prefix}.mlp") + self.mlp = Step3p5MLP( + config=config, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.mlp", + ) self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, config.rms_norm_eps + ) self.prefix = prefix - def add_and_maybe_inplace_all_reduce(self, in1: torch.Tensor, - in2: torch.Tensor) -> torch.Tensor: + def add_and_maybe_inplace_all_reduce( + self, in1: torch.Tensor, in2: torch.Tensor + ) -> torch.Tensor: if not self.use_fused_all_reduce: return in1 + in2 return self.tp_group.all_reduce(in1 + in2) - def forward(self, positions: torch.Tensor, - hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -495,18 +544,17 @@ def forward(self, positions: torch.Tensor, if self.use_moe: shared_output, moe_output = self.moe(hidden_states) ffn_output = self.add_and_maybe_inplace_all_reduce( - moe_output, shared_output) + moe_output, shared_output + ) else: ffn_output = self.mlp(hidden_states) hidden_states = ffn_output + residual return hidden_states + @support_torch_compile class Step3p5Model(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str = "") -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -516,8 +564,9 @@ def __init__(self, self.moe_num_experts 
= config.moe_num_experts - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -532,7 +581,8 @@ def __init__(self, parallel_config=vllm_config.parallel_config, cache_config=cache_config, quant_config=quant_config, - prefix=prefix), + prefix=prefix, + ), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -540,9 +590,9 @@ def __init__(self, else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -551,8 +601,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -567,15 +617,16 @@ def forward( hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) return hidden_states class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): - def __init__( self, *, @@ -588,8 +639,9 @@ def __init__( self.config = config self.vllm_config = vllm_config - self.model = Step3p5Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Step3p5Model( + vllm_config=vllm_config, 
prefix=maybe_prefix(prefix, "model") + ) self.moe_layers: list[FusedMoEBlock] = [] for layer in self.model.layers: @@ -608,15 +660,18 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) # Set MoE hyperparameters self.expert_weights = [] @@ -631,13 +686,16 @@ def __init__( self.num_routed_experts = example_layer.n_routed_experts self.num_redundant_experts = example_layer.n_redundant_experts - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None): - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -674,8 +732,7 @@ def update_physical_experts_metadata( assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer 
in self.moe_layers: assert isinstance(layer, FusedMoEBlock) layer.n_local_physical_experts = num_local_physical_experts @@ -703,7 +760,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): expert_params_mapping = [ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), ] disable_moe_stacked_params = [data[1] for data in expert_params_mapping] @@ -713,12 +770,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): if spec_layer is not None: continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if any(disable_moe_stacked_param in name - for disable_moe_stacked_param in - disable_moe_stacked_params): + if any( + disable_moe_stacked_param in name + for disable_moe_stacked_param in disable_moe_stacked_params + ): continue name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): @@ -738,8 +796,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -747,16 +806,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert loaded_weight.shape[0] == moe_expert_num for expert_id in range(moe_expert_num): loaded_weight_expert = loaded_weight[expert_id] - weight_loader(param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(name) break else: - for (param_name, weight_name, start_idx, - end_idx) in qkv_params_mapping: + for ( + param_name, + weight_name, + start_idx, + end_idx, + ) in qkv_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -766,8 +831,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): dim = param.shape[param.output_dim] begin_idx = int(start_idx * dim) end_idx = int(end_idx * dim) - param_slice = param.narrow(param.output_dim, begin_idx, - end_idx - begin_idx) + param_slice = param.narrow( + param.output_dim, begin_idx, end_idx - begin_idx + ) param_slice.copy_(loaded_weight) loaded_params.add(name) break @@ -778,20 +844,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): logger.warning_once("ignore expert_bias") continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -def get_spec_layer_idx_from_weight_name(config: ModelConfig, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): +def 
get_spec_layer_idx_from_weight_name( + config: ModelConfig, weight_name: str +) -> int | None: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py index 967e6b54202c..365be3fc3f52 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -9,13 +8,16 @@ from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from vllm.model_executor.layers.layernorm import GemmaRMSNorm + from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name from .utils import maybe_prefix @@ -23,53 +25,51 @@ class SharedHead(nn.Module): - def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) - self.head = 
ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(hidden_states) class Step3p5AMultiTokenPredictorLayer(nn.Module): - def __init__( self, config: PretrainedConfig, prefix: str, model_config: ModelConfig, parallel_config: ParallelConfig = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = Step3p5DecoderLayer(model_config, - parallel_config=parallel_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mtp_block") + self.mtp_block = Step3p5DecoderLayer( + model_config, + parallel_config=parallel_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mtp_block", + ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None @@ -77,15 +77,14 @@ def forward( previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states = 
self.mtp_block(positions=positions, - hidden_states=hidden_states) + hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states) return hidden_states class Step3p5AMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -96,19 +95,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - Step3p5AMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - parallel_config=vllm_config.parallel_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): Step3p5AMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + parallel_config=vllm_config.parallel_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.logits_processor = LogitsProcessor(config.vocab_size) @@ -117,12 +119,12 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -136,11 
+138,11 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers[str(self.mtp_start_layer_idx + - current_step_idx)] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states)) + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) return logits def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -148,14 +150,13 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: class Step3p5MTP(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config self.vllm_config = vllm_config - self.model = Step3p5AMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = Step3p5AMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -165,23 +166,23 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> 
torch.Tensor | None: return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: vllm_config = self.vllm_config config = vllm_config.model_config.hf_config @@ -197,7 +198,7 @@ def load_weights(self, weights: Iterable[tuple[str, expert_params_mapping = [ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), ] params_dict = dict(self.named_parameters()) @@ -209,7 +210,7 @@ def load_weights(self, weights: Iterable[tuple[str, if "embed_tokens" not in name and spec_layer is None: continue name = self._rewrite_spec_layer_name(spec_layer, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -219,7 +220,7 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue if "experts" in name or "moe" in name: continue @@ -239,40 +240,46 @@ def load_weights(self, weights: Iterable[tuple[str, continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader for expert_id in range(loaded_weight.shape[0]): loaded_weight_expert = loaded_weight[expert_id] - weight_loader(param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(name) break else: # Skip loading extra bias for GPTQ models. - if name.endswith( - ".bias" - ) and name not in params_dict or "tok_embeddings" in name: + if ( + name.endswith(".bias") + and name not in params_dict + or "tok_embeddings" in name + ): continue if f"{config.num_hidden_layers}.transformer." in name: name = name.replace(".transformer.", ".") if "shared_head" in name: - name = name.replace("shared_head.output", - "shared_head.head") + name = name.replace("shared_head.output", "shared_head.head") if "embed_tokens" in name: - assert hasattr( - self.config, "num_nextn_predict_layers" - ) and self.config.num_nextn_predict_layers > 0 + assert ( + hasattr(self.config, "num_nextn_predict_layers") + and self.config.num_nextn_predict_layers > 0 + ) name = "model.embed_tokens.weight" param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) params_need_to_load = set(params_dict.keys()) @@ -290,7 +297,9 @@ def load_weights(self, weights: Iterable[tuple[str, missing_params = list(params_need_to_load - loaded_params) param_name_example = missing_params[0] raise RuntimeError( - f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization" + "Some parameters like " + 
f"{param_name_example} are not in the checkpoint and will falsely " + "use random initialization" ) return loaded_params @@ -300,7 +309,11 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: Add .mtp_block for modules in transformer layer block for spec layer """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] spec_layer_weight = False for weight_name in spec_layer_weight_names: @@ -309,6 +322,7 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." + ) return name diff --git a/vllm/reasoning/step3p5_reasoning_parser.py b/vllm/reasoning/step3p5_reasoning_parser.py index f558f59d7407..b93f551426fb 100644 --- a/vllm/reasoning/step3p5_reasoning_parser.py +++ b/vllm/reasoning/step3p5_reasoning_parser.py @@ -3,16 +3,15 @@ from collections.abc import Sequence -from vllm.tokenizers import TokenizerLike - from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ) +from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.responses.protocol import ( ResponsesRequest, ) -from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser +from vllm.tokenizers import TokenizerLike class Step3p5ReasoningParser(BaseThinkingReasoningParser): diff --git a/vllm/tool_parsers/step3p5_tool_parser.py b/vllm/tool_parsers/step3p5_tool_parser.py index 32704079e965..b7c8699a03db 100644 --- a/vllm/tool_parsers/step3p5_tool_parser.py +++ b/vllm/tool_parsers/step3p5_tool_parser.py @@ -30,6 +30,7 @@ logger = init_logger(__name__) + 
class StreamingXMLToolCallParser: """ Simplified streaming XML tool call parser @@ -108,15 +109,27 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: new_deltas = self.deltas[initial_delta_count:] # If this chunk contains # but didn't generate '}', then complete it - if (self.current_call_id is not None - and self.function_end_token in xml_chunk): + if ( + self.current_call_id is not None + and self.function_end_token in xml_chunk + ): # - Added '}' (non-empty parameter ending) # - Added '{}' (empty parameter function) - has_function_close = any((td.tool_calls and any( - (tc.function and tc.id == self.current_call_id - and isinstance(tc.function.arguments, str) and - (tc.function.arguments in ("}", "{}"))) - for tc in td.tool_calls)) for td in new_deltas) + has_function_close = any( + ( + td.tool_calls + and any( + ( + tc.function + and tc.id == self.current_call_id + and isinstance(tc.function.arguments, str) + and (tc.function.arguments in ("}", "{}")) + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) if not has_function_close: # Close potentially unclosed element if self.current_param_name: @@ -125,12 +138,25 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: self._end_element("function") # If this chunk contains # but didn't generate final empty delta, then complete it - if (self.current_call_id is not None - and self.tool_call_end_token in xml_chunk): - has_toolcall_close = any((td.tool_calls and any( - (tc.type == "function" and tc.function and tc.function. 
- arguments == "" and tc.id == self.current_call_id) - for tc in td.tool_calls)) for td in new_deltas) + if ( + self.current_call_id is not None + and self.tool_call_end_token in xml_chunk + ): + has_toolcall_close = any( + ( + td.tool_calls + and any( + ( + tc.type == "function" + and tc.function + and tc.function.arguments == "" + and tc.id == self.current_call_id + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) if not has_toolcall_close: # Close potentially unclosed element if self.current_param_name: @@ -142,7 +168,8 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: logger.warning("Error with fallback parsing: %s", e) # Merge newly generated deltas into single response result_delta = self._merge_new_deltas_to_single_response( - initial_delta_count) + initial_delta_count + ) return result_delta else: # No complete elements, check if there's unoutput text content @@ -160,8 +187,9 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: # to prevent accidentally closing new calls # in multi scenarios if self.current_call_id is not None and ( - self.function_end_token in xml_chunk - or self.tool_call_end_token in xml_chunk): + self.function_end_token in xml_chunk + or self.tool_call_end_token in xml_chunk + ): # Close potentially unclosed element if self.current_param_name: self._end_element("parameter") @@ -171,7 +199,8 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: self._end_element("tool_call") # Return the merged delta result generated by this fallback result_delta = self._merge_new_deltas_to_single_response( - initial_delta_count) + initial_delta_count + ) return result_delta # No complete elements, return empty response @@ -209,8 +238,7 @@ def _process_complete_xml_elements(self) -> bool: while self.last_processed_pos < len(self.streaming_buffer): # Find next complete xml element - element, end_pos = self._find_next_complete_element( - self.last_processed_pos) + element, 
end_pos = self._find_next_complete_element(self.last_processed_pos) if element is None: # No complete element found, wait for more data break @@ -224,10 +252,13 @@ def _process_complete_xml_elements(self) -> bool: try: preprocessed_element = self._preprocess_xml_chunk(element) # Check if this is the first tool_call start - if ((preprocessed_element.strip().startswith("") or - preprocessed_element.strip().startswith("") + or preprocessed_element.strip().startswith("") - and self.tool_call_index > 0 and self.current_call_id - and self.current_function_name): + if ( + preprocessed_element.strip().startswith("") + and self.tool_call_index > 0 + and self.current_call_id + and self.current_function_name + ): # Reset parser state but preserve generated deltas if self.current_param_name: self._end_element("parameter") @@ -255,8 +289,7 @@ def _process_complete_xml_elements(self) -> bool: index=self.tool_call_index - 1, id=self.current_call_id, type="function", - function=DeltaFunctionCall(name=None, - arguments=""), + function=DeltaFunctionCall(name=None, arguments=""), ) ], ) @@ -277,9 +310,12 @@ def _process_complete_xml_elements(self) -> bool: def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str: """ - Fallback: fix incomplete ) - Examples: , - Also handles missing = cases: -> , -> + Fallback: fix incomplete ) + Examples: , + Also handles missing = cases: -> , + -> Only fixes tags that pass validation (parameter exists in tool definition) """ # First, handle missing = cases for function tags @@ -296,15 +332,19 @@ def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str: lt_pos = after_tag.find("<", len(pattern)) # Skip if already well-formed - if (gt_pos != -1 and (lt_pos == -1 or gt_pos < lt_pos) - and pattern in after_tag[:gt_pos]): + if ( + gt_pos != -1 + and (lt_pos == -1 or gt_pos < lt_pos) + and pattern in after_tag[:gt_pos] + ): continue # Extract tag name (stop at space, newline, or <) - content = chunk[start_idx + len(pattern):] + content = chunk[start_idx 
+ len(pattern) :] end_pos = next( - (i for i, ch in enumerate(content) if ch in (' ', '\n', '<')), - len(content)) + (i for i, ch in enumerate(content) if ch in (" ", "\n", "<")), + len(content), + ) tag_name = content[:end_pos] if not tag_name: @@ -312,61 +352,62 @@ def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str: # Remove duplicate prefix: ', 1) + chunk = chunk.replace( + f"<{tag_type}={content[:end_pos]}", f"<{tag_type}={tag_name}>", 1 + ) return chunk def _fix_missing_equals_in_function_tag(self, chunk: str) -> str: """ Fix missing = in function tags: or - Examples: + Examples: -> -> Only fixes if function name exists in tool definition """ # already correct - if '" match1 = re.search(pattern1, chunk) if match1: func_name = match1.group(1).strip() # must validate function name exists before fixing if func_name and self._validate_function_name(func_name): original = match1.group(0) - fixed = f'' + fixed = f"" chunk = chunk.replace(original, fixed, 1) return chunk # Pattern 2: (no space, no =) # only match ' + pattern2 = r"" match2 = re.search(pattern2, chunk) if match2: func_name = match2.group(1).strip() # must validate function name exists before fixing if func_name and self._validate_function_name(func_name): original = match2.group(0) - fixed = f'' + fixed = f"" chunk = chunk.replace(original, fixed, 1) return chunk @@ -378,10 +419,13 @@ def _validate_function_name(self, func_name: str) -> bool: return False for tool in self.tools: - if (hasattr(tool, "type") and tool.type == "function" - and hasattr(tool, "function") - and hasattr(tool.function, "name") - and tool.function.name == func_name): + if ( + hasattr(tool, "type") + and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + and tool.function.name == func_name + ): return True return False @@ -392,10 +436,13 @@ def _validate_parameter_name(self, param_name: str) -> bool: return True for tool in self.tools: - if (hasattr(tool, "type") and tool.type == 
"function" - and hasattr(tool, "function") - and hasattr(tool.function, "name") - and tool.function.name == self.current_function_name): + if ( + hasattr(tool, "type") + and tool.type == "function" + and hasattr(tool, "function") + and hasattr(tool.function, "name") + and tool.function.name == self.current_function_name + ): if not hasattr(tool.function, "parameters"): return True params = tool.function.parameters @@ -418,9 +465,11 @@ def _should_skip_element(self, element: str) -> bool: """ # If it's a tool_call XML tag, don't skip - if (element.startswith(self.tool_call_start_token) - or element.startswith(self.function_start_token) - or element.startswith(self.parameter_start_token)): + if ( + element.startswith(self.tool_call_start_token) + or element.startswith(self.function_start_token) + or element.startswith(self.parameter_start_token) + ): return False # If currently not parsing tool calls and not blank, @@ -440,8 +489,7 @@ def _should_skip_element(self, element: str) -> bool: # Skip blank content return not element - def _find_next_complete_element(self, - start_pos: int) -> tuple[str | None, int]: + def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: """ Find next complete XML element from specified position @@ -460,10 +508,12 @@ def _find_next_complete_element(self, if buffer.startswith("<"): # Check if this is an incomplete parameter/function tag # e.g., " not in buffer.split("\n")[0]) - is_incomplete_func = (buffer.startswith("" not in buffer.split("\n")[0]) + is_incomplete_param = ( + buffer.startswith("" not in buffer.split("\n")[0] + ) + is_incomplete_func = ( + buffer.startswith("" not in buffer.split("\n")[0] + ) if is_incomplete_param or is_incomplete_func: # Find the corresponding closing tag @@ -473,9 +523,8 @@ def _find_next_complete_element(self, if closing_pos != -1: # Found closing tag, return complete element including closing tag - complete_element = buffer[:closing_pos + len(closing_tag)] - return 
complete_element, start_pos + closing_pos + len( - closing_tag) + complete_element = buffer[: closing_pos + len(closing_tag)] + return complete_element, start_pos + closing_pos + len(closing_tag) # Need to ensure no new < appears, # find the nearest one between < and > @@ -487,21 +536,23 @@ def _find_next_complete_element(self, return buffer[:tag_end], start_pos + tag_end # Next nearest is >, means found XML element else: - return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1 + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 elif tag_end != -1: return buffer[:tag_end], start_pos + tag_end elif tag_end2 != -1: - return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1 + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 else: # If currently not parsing tool calls (entering a tool_call), # check if starts with or - if buffer == ""[:len(buffer)]: + if buffer == ""[: len(buffer)]: # Might be start of , wait for more data return None, start_pos - elif (buffer.startswith(" str: # Check if this is a tool_call related element is_tool_call = False if chunk.startswith(self.tool_call_start_token) or chunk.startswith( - self.tool_call_end_token): + self.tool_call_end_token + ): is_tool_call = True # Check for function tags (including malformed ones without =) # , , , - if (chunk.startswith(self.function_start_token) - or chunk.startswith(self.function_end_token) - or chunk.startswith(" + # Fallback: fix incomplete # This handles cases like: format -> - processed = re.sub(r"]+)>", r'', - chunk) + processed = re.sub(r"]+)>", r'', chunk) # Handle format -> - processed = re.sub(r"]+)>", r'', - processed) + processed = re.sub(r"]+)>", r'', processed) original_chunk = chunk # If in parameter value accumulation mode @@ -655,30 +713,38 @@ def _preprocess_xml_chunk(self, chunk: str) -> str: # and pass through directly if self._pre_param_buffer == "": # Get current parameter type - param_type = (self._get_param_type( - self._pre_current_param_name) if - 
self._pre_current_param_name else "string") + param_type = ( + self._get_param_type(self._pre_current_param_name) + if self._pre_current_param_name + else "string" + ) # Only these types need deferred parsing to # handle Python literals containing single quotes is_object_type = param_type in ["object"] - is_complex_type = (param_type - in ["array", "arr", "sequence"] - or param_type.startswith("dict") - or param_type.startswith("list")) + is_complex_type = ( + param_type in ["array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) # Only delay when contains container symbols # and has single quotes and is complex type - has_container_hint = (("[" in original_chunk) - or ("{" in original_chunk) - or ("(" in original_chunk)) + has_container_hint = ( + ("[" in original_chunk) + or ("{" in original_chunk) + or ("(" in original_chunk) + ) # Determine if deferred parsing is needed need_defer = False if is_complex_type: # Complex type, always need deferred parsing need_defer = True - elif (is_object_type and has_container_hint - and ("'" in original_chunk)): + elif ( + is_object_type + and has_container_hint + and ("'" in original_chunk) + ): # Object type with container symbols # and single quotes, need deferred parsing need_defer = True @@ -711,8 +777,7 @@ def _emit_delta(self, delta: DeltaMessage): """Emit Delta response (streaming output)""" self.deltas.append(delta) - def _auto_close_open_parameter_if_needed(self, - incoming_tag: str | None = None): + def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None): """Before starting to process new elements, if there are unclosed tags from before, automatically complete their endings to the parser. 
@@ -729,8 +794,7 @@ def _auto_close_open_parameter_if_needed(self, # If about to start new function or tool_call, # and there are unclosed functions, close function first - if incoming_tag in ("function", - "tool_call") and self.current_function_name: + if incoming_tag in ("function", "tool_call") and self.current_function_name: self._end_element("function") # If about to start new tool_call, @@ -764,15 +828,18 @@ def _start_element(self, name: str, attrs: dict[str, str]): self.current_function_name = function_name self.current_function_open = True if function_name: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=function_name, - arguments=""), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=function_name, arguments="" + ), + ) + ] + ) self._emit_delta(delta) elif name.startswith("parameter") or (name == "parameter"): # If previous parameter hasn't ended normally, @@ -792,30 +859,36 @@ def _start_element(self, name: str, attrs: dict[str, str]): # First parameter # start JSON, only output parameter name and colon json_start = f'{{"{param_name}": ' - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, - arguments=json_start), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_start + ), + ) + ] + ) self._emit_delta(delta) self.current_param_is_first = True else: # Subsequent parameters # add comma and parameter name, no quotes json_continue = f', "{param_name}": ' - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - 
id=self.current_call_id, - type="function", - function=DeltaFunctionCall( - name=None, arguments=json_continue), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_continue + ), + ) + ] + ) self._emit_delta(delta) self.current_param_is_first = False @@ -843,17 +916,20 @@ def _char_data(self, data: str): data = data[1:] # Output start quote for string type (if not already output) - if (param_type - in ["string", "str", "text", "varchar", "char", "enum"] - and not self.start_quote_emitted): - quote_delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, arguments='"'), - ) - ]) + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + and not self.start_quote_emitted + ): + quote_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) self._emit_delta(quote_delta) self.start_quote_emitted = True @@ -872,22 +948,23 @@ def _char_data(self, data: str): # convert parameter value by param_type converted_value = self._convert_param_value( - self.current_param_value, param_type) - output_data = self._convert_for_json_streaming( - converted_value, param_type) + self.current_param_value, param_type + ) + output_data = self._convert_for_json_streaming(converted_value, param_type) - delta_data = output_data[len(self.current_param_value_converted):] + delta_data = output_data[len(self.current_param_value_converted) :] self.current_param_value_converted = output_data - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, - 
arguments=delta_data), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=delta_data), + ) + ] + ) self._emit_delta(delta) def _end_element(self, name: str): @@ -898,12 +975,14 @@ def _end_element(self, name: str): # If function or tool_call ends and there are still unclosed parameters, # complete parameter end first - if (name.startswith("function") or name == "function" - or name == "tool_call") and self.current_param_name: + if ( + name.startswith("function") or name == "function" or name == "tool_call" + ) and self.current_param_name: self._auto_close_open_parameter_if_needed() - if (name.startswith("parameter") - or name == "parameter") and self.current_param_name: + if ( + name.startswith("parameter") or name == "parameter" + ) and self.current_param_name: # End current parameter param_name = self.current_param_name param_value = self.current_param_value @@ -912,8 +991,11 @@ def _end_element(self, name: str): # perform overall parsing on raw content # accumulated in preprocessing stage and output once if self.defer_current_parameter: - raw_text = (self.deferred_param_raw_value - if self.deferred_param_raw_value else param_value) + raw_text = ( + self.deferred_param_raw_value + if self.deferred_param_raw_value + else param_value + ) parsed_value = None output_arguments = None try: @@ -924,22 +1006,24 @@ def _end_element(self, name: str): else: raw_for_parse = raw_text parsed_value = ast.literal_eval(raw_for_parse) - output_arguments = json.dumps(parsed_value, - ensure_ascii=False) + output_arguments = json.dumps(parsed_value, ensure_ascii=False) except Exception: # Fallback: output as string as-is output_arguments = json.dumps(raw_text, ensure_ascii=False) parsed_value = raw_text - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - 
function=DeltaFunctionCall(name=None, - arguments=output_arguments), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=output_arguments + ), + ) + ] + ) self._emit_delta(delta) # Clean up and store @@ -956,38 +1040,37 @@ def _end_element(self, name: str): param_type = self._get_param_type(param_name) # convert complete parameter value by param_type - converted_value = self._convert_param_value( - param_value, param_type) + converted_value = self._convert_param_value(param_value, param_type) # Decide whether to add end quote based on parameter type - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: # For empty string parameters, need special handling if not param_value and not self.start_quote_emitted: # No start quote output, # directly output complete empty string - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, - arguments='""'), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='""'), + ) + ] + ) self._emit_delta(delta) else: # Non-empty parameter value, output end quote - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, - arguments='"'), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) self._emit_delta(delta) self.should_emit_end_newline = False @@ -1001,28 +1084,34 @@ def 
_end_element(self, name: str): elif name.startswith("function") or name == "function": # if there are parameters, close JSON object if self.parameters: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, arguments="}"), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="}"), + ) + ] + ) self._emit_delta(delta) # return empty object else: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, arguments="{}"), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="{}"), + ) + ] + ) self._emit_delta(delta) self.current_function_open = False - self.current_function_name = None # Clear function name to prevent duplicate closing + self.current_function_name = ( + None # Clear function name to prevent duplicate closing + ) elif name == "tool_call": # Before ending tool_call, @@ -1034,14 +1123,16 @@ def _end_element(self, name: str): # Close function, ensure output '}' or '{}' self._end_element("function") # Final Delta - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.tool_call_index - 1, - id=self.current_call_id, - type="function", - function=DeltaFunctionCall(name=None, arguments=""), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ] + ) self._emit_delta(delta) # Check if there's text content to output (between tool_calls) @@ -1062,8 +1153,7 @@ def set_tools(self, tools: 
list[ChatCompletionToolsParam] | None): """Set tool configuration information""" self.tools = tools - def _extract_function_name(self, name: str, - attrs: dict[str, str]) -> str | None: + def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: """Extract function name from various formats""" if attrs and "name" in attrs: return attrs["name"] @@ -1075,8 +1165,7 @@ def _extract_function_name(self, name: str, return None - def _extract_parameter_name(self, name: str, - attrs: dict[str, str]) -> str | None: + def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None: """Extract parameter name from various formats""" if attrs and "name" in attrs: return attrs["name"] @@ -1100,25 +1189,31 @@ def _get_param_type(self, param_name: str) -> str: return "string" for tool in self.tools: - if not hasattr(tool, "type") or not (hasattr( - tool, "function") and hasattr(tool.function, "name")): + if not hasattr(tool, "type") or not ( + hasattr(tool, "function") and hasattr(tool.function, "name") + ): continue - if (tool.type == "function" - and tool.function.name == self.current_function_name): + if ( + tool.type == "function" + and tool.function.name == self.current_function_name + ): if not hasattr(tool.function, "parameters"): return "string" params = tool.function.parameters if isinstance(params, dict) and "properties" in params: properties = params["properties"] if param_name in properties and isinstance( - properties[param_name], dict): + properties[param_name], dict + ): return self.repair_param_type( - str(properties[param_name].get("type", "string"))) + str(properties[param_name].get("type", "string")) + ) elif isinstance(params, dict) and param_name in params: param_config = params[param_name] if isinstance(param_config, dict): return self.repair_param_type( - str(param_config.get("type", "string"))) + str(param_config.get("type", "string")) + ) break return "string" @@ -1130,18 +1225,22 @@ def repair_param_type(self, 
param_type: str) -> str: Returns: Repaired parameter type """ - if (param_type in ["string", "str", "text", "varchar", "char", "enum"] - or param_type.startswith("int") - or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned") - or param_type.startswith("num") - or param_type.startswith("float") - or param_type in ["boolean", "bool", "binary"] - or (param_type in ["object", "array", "arr", "sequence"] - or param_type.startswith("dict") - or param_type.startswith("list"))): + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + or param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + or param_type.startswith("num") + or param_type.startswith("float") + or param_type in ["boolean", "bool", "binary"] + or ( + param_type in ["object", "array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + ): return param_type else: return "string" @@ -1161,29 +1260,32 @@ def _convert_param_value(self, param_value: str, param_type: str) -> Any: param_type = param_type.strip().lower() if param_type in ["string", "str", "text", "varchar", "char", "enum"]: return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): try: return int(param_value) except (ValueError, TypeError): logger.warning( - "Parsed value '%s' is not an integer, " - "degenerating to string.", + "Parsed value '%s' is not an integer, degenerating to string.", param_value, ) return param_value elif 
param_type.startswith("num") or param_type.startswith("float"): try: float_param_value: float = float(param_value) - return (float_param_value if float_param_value - - int(float_param_value) != 0 else - int(float_param_value)) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) except (ValueError, TypeError): logger.warning( - "Parsed value '%s' is not a float, " - "degenerating to string.", + "Parsed value '%s' is not a float, degenerating to string.", param_value, ) return param_value @@ -1193,8 +1295,7 @@ def _convert_param_value(self, param_value: str, param_type: str) -> Any: else: return param_value - def _convert_for_json_streaming(self, converted_value: Any, - param_type: str) -> str: + def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str: """Convert converted_value based on whether it's empty and if type is string Args: @@ -1253,7 +1354,6 @@ def _reset_xml_parser_after_tool_call(self): @ToolParserManager.register_module("step3p5") class Step3p5ToolParser(ToolParser): - def __init__(self, tokenizer: TokenizerLike): super().__init__(tokenizer) self.parser = StreamingXMLToolCallParser() @@ -1262,8 +1362,9 @@ def __init__(self, tokenizer: TokenizerLike): self.prev_tool_call_arr: list[dict] = [] self.streamed_args_for_tool: list[str] = [] - logger.info("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) def extract_tool_calls( self, @@ -1295,32 +1396,35 @@ def extract_tool_calls( name=tool_call.function.name, arguments=tool_call.function.arguments, ), - )) + ) + ) # Update tool call tracking arrays for compatibility - tool_index = (tool_call.index - if tool_call.index is not None else - len(self.prev_tool_call_arr) - 1) + tool_index = ( + tool_call.index + if tool_call.index is not None + else len(self.prev_tool_call_arr) - 1 + ) # Ensure we have enough 
entries in our tracking arrays while len(self.prev_tool_call_arr) <= tool_index: - self.prev_tool_call_arr.append({ - "name": "", - "arguments": "" - }) + self.prev_tool_call_arr.append({"name": "", "arguments": ""}) while len(self.streamed_args_for_tool) <= tool_index: self.streamed_args_for_tool.append("") # Update tool call information self.prev_tool_call_arr[tool_index]["name"] = ( - tool_call.function.name) + tool_call.function.name + ) self.prev_tool_call_arr[tool_index]["arguments"] = ( - tool_call.function.arguments) + tool_call.function.arguments + ) # Update streamed arguments if tool_call.function.arguments: self.streamed_args_for_tool[tool_index] = ( - tool_call.function.arguments) + tool_call.function.arguments + ) return ExtractedToolCallInformation( tool_calls=tool_calls, @@ -1352,10 +1456,14 @@ def extract_tool_calls_streaming( # to correctly output tool_call field if not delta_text and delta_token_ids: open_calls = current_text.count( - self.parser.tool_call_start_token) - current_text.count( - self.parser.tool_call_end_token) - if (open_calls == 0 and self.parser.tool_call_index > 0 - or not self.parser.tool_call_index and current_text): + self.parser.tool_call_start_token + ) - current_text.count(self.parser.tool_call_end_token) + if ( + open_calls == 0 + and self.parser.tool_call_index > 0 + or not self.parser.tool_call_index + and current_text + ): return DeltaMessage(content="") return None @@ -1366,32 +1474,34 @@ def extract_tool_calls_streaming( if result and result.tool_calls: for tool_call in result.tool_calls: if tool_call.function: - tool_index = (tool_call.index - if tool_call.index is not None else - len(self.prev_tool_call_arr) - 1) + tool_index = ( + tool_call.index + if tool_call.index is not None + else len(self.prev_tool_call_arr) - 1 + ) # Ensure we have enough entries in our tracking arrays while len(self.prev_tool_call_arr) <= tool_index: - self.prev_tool_call_arr.append({ - "name": "", - "arguments": "" - }) + 
self.prev_tool_call_arr.append({"name": "", "arguments": ""}) while len(self.streamed_args_for_tool) <= tool_index: self.streamed_args_for_tool.append("") # Update tool name if provided if tool_call.function.name: self.prev_tool_call_arr[tool_index]["name"] = ( - tool_call.function.name) + tool_call.function.name + ) # Update arguments incrementally if tool_call.function.arguments is not None: # Concatenate the incremental arguments # to the existing streamed arguments self.prev_tool_call_arr[tool_index]["arguments"] += ( - tool_call.function.arguments) + tool_call.function.arguments + ) self.streamed_args_for_tool[tool_index] += ( - tool_call.function.arguments) + tool_call.function.arguments + ) return result def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool: diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py index 814486ed8b36..d2ad927da8e8 100644 --- a/vllm/transformers_utils/configs/step3p5.py +++ b/vllm/transformers_utils/configs/step3p5.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Any from transformers.configuration_utils import PretrainedConfig @@ -25,27 +25,27 @@ def __init__( moe_top_k: int = 4, moe_layer_offset: int = 0, moe_dynamic_exp_p: float = 1.0, - rope_theta: Optional[Union[float, list[float]]] = 500000, - rope_scaling: Optional[dict[str, Any]] = None, - head_dim: Optional[int] = None, - share_expert_dim: Optional[int] = None, + rope_theta: float | list[float] | None = 500000, + rope_scaling: dict[str, Any] | None = None, + head_dim: int | None = None, + share_expert_dim: int | None = None, norm_expert_weight: bool = True, - bos_token_id: Optional[Union[list[int], int]] = None, - eos_token_id: Optional[Union[list[int], int]] = None, + bos_token_id: list[int] | int | None = None, + eos_token_id: list[int] | int | None = None, 
moe_router_activation: str = "softmax", moe_router_scaling_factor: float = 1.0, att_impl_type: str = "GQA", use_head_wise_attn_gate: bool = False, use_moe_router_bias: bool = True, need_fp32_gate: bool = True, - layer_types: Optional[list[str]] = None, - use_rope_layers: Optional[list[bool]] = None, - yarn_only_types: Optional[list[str]] = None, - attention_other_setting: Optional[dict[str, Any]] = None, + layer_types: list[str] | None = None, + use_rope_layers: list[bool] | None = None, + yarn_only_types: list[str] | None = None, + attention_other_setting: dict[str, Any] | None = None, num_nextn_predict_layers: int = 0, - swiglu_limits: Optional[list[float]] = None, - swiglu_limits_shared: Optional[list[float]] = None, - max_position_embeddings: Optional[int] = None, + swiglu_limits: list[float] | None = None, + swiglu_limits_shared: list[float] | None = None, + max_position_embeddings: int | None = None, **kwargs, ): self.hidden_size = hidden_size From 67ced2d43aeccbeac1a5f1e38ae5b1d5a27419b0 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Fri, 30 Jan 2026 17:36:53 +0800 Subject: [PATCH 20/34] fix: remove moe_dynamic_exp_p from config --- vllm/model_executor/models/step3p5.py | 2 -- vllm/transformers_utils/configs/step3p5.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 161d17190113..d20ccd6ac47a 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -307,8 +307,6 @@ def __init__( f"the number of experts {config.moe_num_experts}." ) - assert config.moe_dynamic_exp_p == 1, "Only support dynamic exp p=1" - self.use_moe_router_bias = config.use_moe_router_bias assert self.use_moe_router_bias, "Only support use_moe_router_bias is true." 
self.routed_scaling_factor = config.moe_router_scaling_factor diff --git a/vllm/transformers_utils/configs/step3p5.py b/vllm/transformers_utils/configs/step3p5.py index d2ad927da8e8..435afd938212 100644 --- a/vllm/transformers_utils/configs/step3p5.py +++ b/vllm/transformers_utils/configs/step3p5.py @@ -24,7 +24,6 @@ def __init__( moe_num_experts: int = 16, moe_top_k: int = 4, moe_layer_offset: int = 0, - moe_dynamic_exp_p: float = 1.0, rope_theta: float | list[float] | None = 500000, rope_scaling: dict[str, Any] | None = None, head_dim: int | None = None, @@ -63,7 +62,6 @@ def __init__( self.num_experts_per_tok = moe_top_k self.moe_top_k = moe_top_k self.moe_layer_offset = moe_layer_offset - self.moe_dynamic_exp_p = moe_dynamic_exp_p self.rope_theta = rope_theta self.rope_scaling = rope_scaling From 9fded696d84f3c5c78053e303f220f4c496f93b9 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Fri, 30 Jan 2026 22:18:51 +0800 Subject: [PATCH 21/34] fix: mtp3 weights load error --- vllm/model_executor/models/step3p5.py | 7 +++++++ vllm/model_executor/models/step3p5_mtp.py | 5 +---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index d20ccd6ac47a..eef1d3da6cbd 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -767,6 +767,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue # skip spec decode layers for main model + # Skip any layers beyond the main model's depth (e.g., MTP layers) + if name.startswith("model.layers."): + parts = name.split(".") + if len(parts) > 2 and parts[2].isdigit(): + layer_idx = int(parts[2]) + if layer_idx >= config.num_hidden_layers: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/vllm/model_executor/models/step3p5_mtp.py 
b/vllm/model_executor/models/step3p5_mtp.py index 365be3fc3f52..7d715de6a2ce 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -183,9 +183,6 @@ def compute_logits( return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - vllm_config = self.vllm_config - config = vllm_config.model_config.hf_config - stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -266,7 +263,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ): continue - if f"{config.num_hidden_layers}.transformer." in name: + if spec_layer is not None and ".transformer." in name: name = name.replace(".transformer.", ".") if "shared_head" in name: name = name.replace("shared_head.output", "shared_head.head") From 06892530d5902c73cad8e4aa795f654199a96fb3 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Sat, 31 Jan 2026 12:35:25 +0800 Subject: [PATCH 22/34] format: some review comments fix --- vllm/model_executor/models/step3p5.py | 276 ++++++++++++++------------ 1 file changed, 146 insertions(+), 130 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index eef1d3da6cbd..bc6ec65f7b9a 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -45,6 +45,7 @@ from .interfaces import MixtureOfExperts, SupportsPP from .utils import ( + AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -92,7 +93,7 @@ def __init__( self.prefix = prefix self.hidden_size = hidden_size self.limit = None - layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + layer_idx = extract_layer_index(prefix) if ( config.swiglu_limits_shared and config.swiglu_limits_shared[layer_idx] is not None @@ -103,7 +104,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up, _ = 
self.gate_up_proj(hidden_states) - intermediate_act = self.act_fn.forward_cuda(gate_up) + intermediate_act = self.act_fn(gate_up) output, _ = self.down_proj(intermediate_act) return output @@ -241,25 +242,23 @@ def __init__( self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2 ) - def qk_norm_rope(self, q, k, positions): + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm inline similar to Qwen3 MOE attention q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head.contiguous()) q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head.contiguous()) k = k_by_head.view(k.shape) if self.use_rope: q, k = self.rotary_emb(positions, q, k) - return q, k - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.qk_norm_rope(q, k, positions) attn_output = self.attn(q, k, v) if self.use_head_wise_attn_gate: extra_dims, _ = self.g_proj(hidden_states) @@ -623,6 +622,137 @@ def forward( return hidden_states + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + config = self.config + assert config.num_attention_groups > 1, "Only support GQA" + qkv_params_mapping = [] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + expert_params_mapping = [ + (".moe.experts.w13_weight", 
".moe.gate_proj.weight", "w1"), + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), + ] + + disable_moe_stacked_params = [data[1] for data in expert_params_mapping] + + for name, loaded_weight in weights: + if name.startswith("model."): + local_name = name[len("model.") :] + full_name = name + else: + local_name = name + full_name = f"model.{name}" if name else "model" + + spec_layer = get_spec_layer_idx_from_weight_name(config, full_name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + # Skip any layers beyond the main model's depth (e.g., MTP layers) + if full_name.startswith("model.layers."): + parts = full_name.split(".") + if len(parts) > 2 and parts[2].isdigit(): + layer_idx = int(parts[2]) + if layer_idx >= config.num_hidden_layers: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in local_name: + continue + if any( + disable_moe_stacked_param in local_name + for disable_moe_stacked_param in disable_moe_stacked_params + ): + continue + replaced_name = local_name.replace(weight_name, param_name) + if is_pp_missing_parameter(replaced_name, self): + continue + if replaced_name not in params_dict: + continue + param = params_dict[replaced_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(replaced_name) + break + else: + for param_name, weight_name, shard_id in expert_params_mapping: + if weight_name not in local_name: + continue + replaced_name = local_name.replace(weight_name, param_name) + if is_pp_missing_parameter(replaced_name, self): + continue + if ( + replaced_name.endswith(".bias") + or replaced_name.endswith("_bias") + ) and replaced_name not in params_dict: + continue + if replaced_name not in params_dict: + continue + param = params_dict[replaced_name] + weight_loader = param.weight_loader + moe_expert_num = self.moe_num_experts + 
assert loaded_weight.shape[0] == moe_expert_num + for expert_id in range(moe_expert_num): + loaded_weight_expert = loaded_weight[expert_id] + weight_loader( + param, + loaded_weight_expert, + replaced_name, + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(replaced_name) + break + else: + for ( + param_name, + weight_name, + start_idx, + end_idx, + ) in qkv_params_mapping: + if weight_name not in local_name: + continue + replaced_name = local_name.replace(weight_name, param_name) + if is_pp_missing_parameter(replaced_name, self): + continue + if replaced_name not in params_dict: + continue + param = params_dict[replaced_name] + dim = param.shape[param.output_dim] + begin_idx = int(start_idx * dim) + end_idx = int(end_idx * dim) + param_slice = param.narrow( + param.output_dim, begin_idx, end_idx - begin_idx + ) + param_slice.copy_(loaded_weight) + loaded_params.add(replaced_name) + break + else: + if is_pp_missing_parameter(local_name, self): + continue + if "expert_bias" in local_name: + logger.warning_once("ignore expert_bias") + continue + if local_name not in params_dict: + continue + param = params_dict[local_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(local_name) + return loaded_params + class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): def __init__( @@ -738,123 +868,9 @@ def update_physical_experts_metadata( layer.n_redundant_experts = self.num_redundant_experts layer.experts.update_expert_map() - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - vllm_config = self.vllm_config - config = vllm_config.model_config.hf_config - assert config.num_attention_groups > 1, "Only support GQA" - qkv_params_mapping = [] - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - 
("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params = set() - - expert_params_mapping = [ - (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), - (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), - ] - - disable_moe_stacked_params = [data[1] for data in expert_params_mapping] - - for name, loaded_weight in weights: - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - # Skip any layers beyond the main model's depth (e.g., MTP layers) - if name.startswith("model.layers."): - parts = name.split(".") - if len(parts) > 2 and parts[2].isdigit(): - layer_idx = int(parts[2]) - if layer_idx >= config.num_hidden_layers: - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - if any( - disable_moe_stacked_param in name - for disable_moe_stacked_param in disable_moe_stacked_params - ): - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - loaded_params.add(name) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Skip loading extra bias for GPTQ models. 
- if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - moe_expert_num = self.model.moe_num_experts - assert loaded_weight.shape[0] == moe_expert_num - for expert_id in range(moe_expert_num): - loaded_weight_expert = loaded_weight[expert_id] - weight_loader( - param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id, - ) - loaded_params.add(name) - break - else: - for ( - param_name, - weight_name, - start_idx, - end_idx, - ) in qkv_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - dim = param.shape[param.output_dim] - begin_idx = int(start_idx * dim) - end_idx = int(end_idx * dim) - param_slice = param.narrow( - param.output_dim, begin_idx, end_idx - begin_idx - ) - param_slice.copy_(loaded_weight) - loaded_params.add(name) - break - else: - if is_pp_missing_parameter(name, self): - continue - if "expert_bias" in name: - logger.warning_once("ignore expert_bias") - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) def get_spec_layer_idx_from_weight_name( From b2b55de2df5bae32f9fe9c1047299d927a9e6bd2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 31 Jan 2026 06:44:56 +0000 Subject: [PATCH 23/34] Refactor moe Signed-off-by: Jee Jee Li --- vllm/model_executor/models/step3p5.py | 147 +++++++++++++------------- 1 file changed, 71 insertions(+), 76 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index bc6ec65f7b9a..f7b23993e216 100644 
--- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -9,7 +9,7 @@ from torch import nn from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import ( get_dp_group, get_ep_group, @@ -47,6 +47,7 @@ from .utils import ( AutoWeightsLoader, PPMissingLayer, + WeightsMapper, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, @@ -274,20 +275,21 @@ def forward( class FusedMoEBlock(nn.Module): def __init__( self, - config: ModelConfig, - parallel_config: ParallelConfig, - shared_experts: torch.nn.Module, - quant_config: QuantizationConfig | None = None, - reduce_results: bool = True, + vllm_config: VllmConfig, prefix: str = "", ): super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() self.layer_idx = extract_layer_index(prefix) self.ep_size = get_ep_group().device_group.size() self.ep_rank = get_ep_group().device_group.rank() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.hidden_size = config.hidden_size self.enable_eplb = parallel_config.enable_eplb self.n_routed_experts = config.moe_num_experts self.n_logical_experts = self.n_routed_experts @@ -336,18 +338,29 @@ def __init__( swiglu_limit, ) + self.share_expert = Step3p5MLP( + config=config, + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + hidden_act="silu", + reduce_results=False, + quant_config=quant_config, + prefix=f"{prefix}.share_expert", + ) + self.experts = SharedFusedMoE( - shared_experts=shared_experts, + shared_experts=self.share_expert, num_experts=config.moe_num_experts, top_k=config.moe_top_k, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=reduce_results, + 
reduce_results=False, renormalize=config.norm_expert_weight, quant_config=quant_config, activation=activation, prefix=f"{prefix}.experts", - custom_routing_function=self.router_bias_func, + scoring_func=getattr(config, "moe_router_activation", "sigmoid"), + e_score_correction_bias=self.router_bias, routed_scaling_factor=config.moe_router_scaling_factor, enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, @@ -359,54 +372,55 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) - def router_bias_func( - self, - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - ): - gate_prob = torch.sigmoid(gating_output.float()) - gate_prob_with_bias = gate_prob + self.router_bias.unsqueeze(0) - _, indices = torch.topk(gate_prob_with_bias, k=topk, dim=1) - topk_prob = torch.gather(gate_prob, 1, indices) - expert_topk_weight = topk_prob - if renormalize: - expert_topk_weight = expert_topk_weight / ( - torch.sum(expert_topk_weight, dim=-1, keepdim=True) + 1e-20 + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits ) - expert_topk_weight *= self.routed_scaling_factor - return expert_topk_weight, indices.to(torch.int32) - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - orig_shape = hidden_states.shape - hidden_dim = hidden_states.shape[-1] - hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = ( - hidden_states.to(torch.float32) @ 
self.gate.weight.to(torch.float32).t() - ) - shared_out, final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + shared_output, final_hidden_states = fused_moe_out + if self.share_expert is None: + assert shared_output is None + + if self.share_expert is None: + assert shared_output is None + + if self.share_expert is not None: + assert shared_output is not None + final_hidden_states += shared_output - return shared_out, final_hidden_states.view(orig_shape) + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_dim) class Step3p5DecoderLayer(nn.Module): def __init__( self, - config: ModelConfig, - parallel_config: ParallelConfig, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, + vllm_config: VllmConfig, prefix: str = "", ) -> None: super().__init__() - config = config.hf_config + config = vllm_config.model_config.hf_config self.hidden_size = config.hidden_size - layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + layer_idx = extract_layer_index(prefix) self.layer_idx = layer_idx + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config if cache_config is not None: cache_config.sliding_window = None if config.att_impl_type == "GQA": @@ -476,28 +490,8 @@ def __init__( else: moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] if layer_idx in moe_layers_idx: - reduce_results = True - if ( - self.use_fused_all_reduce - or self.tp_group.world_size == 1 - and get_ep_group().world_size == 1 - ): - reduce_results = False - self.share_expert = Step3p5MLP( - config=config, - hidden_size=self.hidden_size, - intermediate_size=config.share_expert_dim, - hidden_act="silu", - reduce_results=reduce_results, - quant_config=quant_config, - prefix=f"{prefix}.share_expert", - ) self.moe = FusedMoEBlock( - 
shared_experts=self.share_expert, - config=config, - parallel_config=parallel_config, - quant_config=quant_config, - reduce_results=reduce_results, + vllm_config, prefix=f"{prefix}.moe", ) self.use_moe = True @@ -539,10 +533,7 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) if self.use_moe: - shared_output, moe_output = self.moe(hidden_states) - ffn_output = self.add_and_maybe_inplace_all_reduce( - moe_output, shared_output - ) + ffn_output = self.moe(hidden_states) else: ffn_output = self.mlp(hidden_states) hidden_states = ffn_output + residual @@ -553,9 +544,11 @@ def forward( class Step3p5Model(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + + self.vllm_config = vllm_config config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config + # cache_config = vllm_config.cache_config + # quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size self.config = config @@ -574,10 +567,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Step3p5DecoderLayer( - config=vllm_config.model_config, - parallel_config=vllm_config.parallel_config, - cache_config=cache_config, - quant_config=quant_config, + vllm_config, prefix=prefix, ), prefix=f"{prefix}.layers", @@ -754,7 +744,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params + class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".share_expert.": ".moe.share_expert."} + ) + def __init__( self, *, @@ -870,7 +865,7 @@ def update_physical_experts_metadata( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return 
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_spec_layer_idx_from_weight_name( @@ -881,6 +876,6 @@ def get_spec_layer_idx_from_weight_name( ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx + i}."): + if weight_name.startswith(f"layers.{layer_idx + i}."): return layer_idx + i return None From 128c8c72a0ca10afa2255e3e937efe7f47798a26 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 31 Jan 2026 07:39:00 +0000 Subject: [PATCH 24/34] fix shared moe Signed-off-by: Jee Jee Li --- vllm/model_executor/models/step3p5.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index f7b23993e216..07971acd1b34 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -308,6 +308,13 @@ def __init__( f"the number of experts {config.moe_num_experts}." ) + self.gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) self.use_moe_router_bias = config.use_moe_router_bias assert self.use_moe_router_bias, "Only support use_moe_router_bias is true." 
self.routed_scaling_factor = config.moe_router_scaling_factor @@ -350,6 +357,7 @@ def __init__( self.experts = SharedFusedMoE( shared_experts=self.share_expert, + gate=self.gate, num_experts=config.moe_num_experts, top_k=config.moe_top_k, hidden_size=config.hidden_size, @@ -365,14 +373,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, ) - self.gate = ReplicatedLinear( - config.hidden_size, - config.moe_num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -744,7 +745,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params - class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={".share_expert.": ".moe.share_expert."} From d66b2de4de66a1549aee29e9c555442243366072 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 31 Jan 2026 09:26:22 +0000 Subject: [PATCH 25/34] NIT Signed-off-by: Jee Jee Li --- docs/models/supported_models.md | 1 + tests/models/registry.py | 3 +++ vllm/model_executor/models/step3p5.py | 2 -- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index e79899e6fe7a..92d610de6d4d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -456,6 +456,7 @@ th { | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | | `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ | +| `Step3p5ForCausalLM` | Step-Audio | `stepfun-ai/step-3.5-flash`, etc. 
| | ✅︎ | | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index c2760d37f4cd..398bcbdd9005 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -484,6 +484,9 @@ def check_available_online( "Step1ForCausalLM": _HfExamplesInfo( "stepfun-ai/Step-Audio-EditX", trust_remote_code=True ), + "Step3p5ForCausalLM": _HfExamplesInfo( + "stepfun-ai/step-3.5-flash", is_available_online=False + ), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 07971acd1b34..6c9aad10e966 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -548,8 +548,6 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.vllm_config = vllm_config config = vllm_config.model_config.hf_config - # cache_config = vllm_config.cache_config - # quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size self.config = config From 82631afd56e361bb305f274368ca2fc623fb90e9 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 31 Jan 2026 09:29:08 +0000 Subject: [PATCH 26/34] NIT Signed-off-by: Jee Jee Li --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 92d610de6d4d..14e78d383523 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ 
-456,7 +456,7 @@ th { | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | | `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ | -| `Step3p5ForCausalLM` | Step-Audio | `stepfun-ai/step-3.5-flash`, etc. | | ✅︎ | +| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/step-3.5-flash`, etc. | | ✅︎ | | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | From f35951bf44306b1faf3d07a38dffa3fc731dd2f5 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Sat, 31 Jan 2026 21:08:28 +0800 Subject: [PATCH 27/34] feat: keep router logits in fp32 precision --- vllm/model_executor/models/step3p5.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 6c9aad10e966..23fcc616db32 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -378,19 +378,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - fused_moe_out = self.experts( - hidden_states=hidden_states, router_logits=hidden_states - ) - else: - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - fused_moe_out = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + # 
router_logits: (num_tokens, n_experts) + # Use FP32 for higher precision. + router_logits = ( + hidden_states.to(torch.float32) @ self.gate.weight.to(torch.float32).t() + ) + shared_output, final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) - shared_output, final_hidden_states = fused_moe_out if self.share_expert is None: assert shared_output is None From 2faeea8e0ae5414437aa227432da73d0637dca18 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 31 Jan 2026 14:37:54 +0000 Subject: [PATCH 28/34] gate Signed-off-by: Jee Jee Li --- vllm/model_executor/models/step3p5.py | 37 ++++++++++++++++++++------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 23fcc616db32..cbda28254935 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -7,6 +7,7 @@ import torch from torch import nn +from torch.nn.parameter import Parameter from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -58,6 +59,19 @@ logger = init_logger(__name__) +class FP32ReplicatedLinear(ReplicatedLinear): + """ + Use FP32 for higher precision. + """ + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + assert self.params_dtype == torch.float32 + return super().forward(x.to(torch.float32)) + + class Step3p5MLP(nn.Module): def __init__( self, @@ -308,11 +322,12 @@ def __init__( f"the number of experts {config.moe_num_experts}." ) - self.gate = ReplicatedLinear( + self.gate = FP32ReplicatedLinear( config.hidden_size, config.moe_num_experts, bias=False, quant_config=None, + params_dtype=torch.float32, # Use FP32 for higher precision. 
prefix=f"{prefix}.gate", ) self.use_moe_router_bias = config.use_moe_router_bias @@ -378,15 +393,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) - # Use FP32 for higher precision. - router_logits = ( - hidden_states.to(torch.float32) @ self.gate.weight.to(torch.float32).t() - ) - shared_output, final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + shared_output, final_hidden_states = fused_moe_out if self.share_expert is None: assert shared_output is None From 5147782201502324c400a400a52c1a0d5f4f844e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 1 Feb 2026 12:17:46 +0000 Subject: [PATCH 29/34] Fix MTP Signed-off-by: Jee Jee Li --- vllm/model_executor/models/step3p5.py | 4 +++- vllm/model_executor/models/step3p5_mtp.py | 22 ++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index cbda28254935..7dd764068fe7 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -889,6 +889,8 @@ def get_spec_layer_idx_from_weight_name( ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"layers.{layer_idx + i}."): + if weight_name.startswith( + f"layers.{layer_idx + i}." 
# Step3p5Model + ) or weight_name.startswith(f"model.layers.{layer_idx + i}."): # Step3p5MTP return layer_idx + i return None diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py index 7d715de6a2ce..83e43dce5114 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -6,7 +6,7 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -43,24 +43,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Step3p5AMultiTokenPredictorLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + vllm_config: VllmConfig, prefix: str, - model_config: ModelConfig, - parallel_config: ParallelConfig = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() - + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.shared_head = SharedHead(config=config, quant_config=quant_config) self.mtp_block = Step3p5DecoderLayer( - model_config, - parallel_config=parallel_config, - cache_config=cache_config, - quant_config=quant_config, + vllm_config, prefix=f"{prefix}.mtp_block", ) @@ -98,12 +92,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.layers = torch.nn.ModuleDict( { str(idx): Step3p5AMultiTokenPredictorLayer( - config, + vllm_config, f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - 
parallel_config=vllm_config.parallel_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, ) for idx in range( self.mtp_start_layer_idx, From 1c4d28c2eb03215e52c5fb70ba8820e827b2d255 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 1 Feb 2026 14:42:19 +0000 Subject: [PATCH 30/34] triton act Signed-off-by: Jee Jee Li --- tests/kernels/core/test_activation.py | 10 ++- vllm/model_executor/layers/activation.py | 70 ++++++++++++++----- vllm/model_executor/layers/fused_moe/utils.py | 5 +- vllm/model_executor/models/step3p5.py | 13 ++-- 4 files changed, 70 insertions(+), 28 deletions(-) diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index 8f28e967aec3..66727a3099ee 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -17,6 +17,8 @@ QuickGELU, SiluAndMul, SwigluOAIAndMul, + SwigluStepAndMul, + swiglustep_and_mul_triton, ) from vllm.utils.torch_utils import set_random_seed @@ -36,6 +38,7 @@ "gelu_tanh", "fatrelu", "swigluoai_and_mul", + "swiglustep_and_mul", ], ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -75,9 +78,12 @@ def test_act_and_mul( elif activation == "swigluoai_and_mul": layer = SwigluOAIAndMul() fn = torch.ops._C.swigluoai_and_mul + elif activation == "swiglustep_and_mul": + layer = SwigluStepAndMul() + fn = swiglustep_and_mul_triton out = layer(x) ref_out = layer.forward_native(x) - if activation == "swigluoai_and_mul": + if activation in ["swigluoai_and_mul", "swiglustep_and_mul"]: rtol = { # For fp16, change the relative tolerance from 1e-3 to 2e-3 torch.float16: 2e-3, @@ -104,7 +110,7 @@ def _get_rtol(output) -> float: opcheck(fn, (out, x, threshold)) elif activation == "swigluoai_and_mul": opcheck(fn, (out, x, layer.alpha, layer.limit)) - else: + elif activation != "swiglustep_and_mul": opcheck(fn, (out, x)) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 
b242b892f097..b53a37a31761 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -17,29 +17,61 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils.collection_utils import LazyDict logger = init_logger(__name__) -def swiglustep_and_mul_out( - out: torch.Tensor, - x: torch.Tensor, - limit: float = 7.0, -) -> torch.Tensor: - """Out-variant of swiglustep activation. - - Writes into `out`: - silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit) - """ - # Prefer the fused custom op when available (CUDA); fallback to PyTorch ops - # otherwise. - gate, up = x.chunk(2, dim=-1) - gate = F.silu(gate) - gate = gate.clamp(max=limit) - up = up.clamp(min=-limit, max=limit) - out.copy_(gate * up) - return out +@triton.jit +def _swiglustep_and_mul_kernel( + o_ptr, + o_stride, + x_ptr, + x_stride, + limit: tl.constexpr, + d: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + i = tl.program_id(axis=0).to(tl.int64) + j = tl.program_id(axis=1) + o_row_ptr = o_ptr + o_stride * i + x_row_ptr = x_ptr + x_stride * i + offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < d + + gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32) + up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32) + + gate_silu = tl.sigmoid(gate) * gate + gate_clamped = tl.minimum(gate_silu, limit) + up_clamped = tl.minimum(tl.maximum(up, -limit), limit) + + result = gate_clamped * up_clamped + result = result.to(x_ptr.dtype.element_ty) + tl.store(o_row_ptr + offsets, result, mask=mask) + + +def swiglustep_and_mul_triton( + output: torch.Tensor, input: torch.Tensor, limit: float = 7.0 +): + b, n = input.shape + assert input.ndim == 2 + assert n % 2 == 0 + d = n // 2 + + def grid(meta): + return (b, triton.cdiv(d, meta["BLOCK_SIZE"])) + + _swiglustep_and_mul_kernel[grid]( 
+ output, + output.stride(0), + input, + input.stride(0), + limit=limit, + d=d, + BLOCK_SIZE=1024, + ) # --8<-- [start:fatrelu_and_mul] @@ -355,7 +387,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - swiglustep_and_mul_out(out, x, self.limit) + swiglustep_and_mul_triton(out, x, self.limit) return out def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 50c216c43b81..75873a92abdb 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -359,9 +359,10 @@ def apply_moe_activation( elif activation == "swigluoai": torch.ops._C.swigluoai_and_mul(output, input) elif activation == "swiglustep": - from vllm.model_executor.layers.activation import swiglustep_and_mul_out + from vllm.model_executor.layers.activation import swiglustep_and_mul_triton + + swiglustep_and_mul_triton(output, input) - swiglustep_and_mul_out(output, input) # Activations without gated multiplication elif activation == SILU_NO_MUL: output.copy_(F.silu(input)) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 7dd764068fe7..82266869da70 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -341,11 +341,13 @@ def __init__( assert self.need_fp32_gate, ( "Router logits must use FP32 precision for numerical stability." 
) - layer_idx = int(prefix.split("layers.")[1].split(".")[0]) + activation = "silu" swiglu_limits = config.swiglu_limits or [] swiglu_limit = ( - swiglu_limits[layer_idx] if layer_idx < len(swiglu_limits) else None + swiglu_limits[self.layer_idx] + if self.layer_idx < len(swiglu_limits) + else None ) if swiglu_limit not in (None, 0): swiglu_limit = float(swiglu_limit) @@ -353,9 +355,9 @@ def __init__( "Swiglu limit in fused moe block only suport 7.0 now." ) activation = "swiglustep" - logger.info( + logger.debug( "step3p5 layer_idx: %s, activation: %s, limit: %s", - layer_idx, + self.layer_idx, activation, swiglu_limit, ) @@ -369,7 +371,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.share_expert", ) - + if get_tensor_model_parallel_rank() == 0: + print(f"{self.layer_idx}: moe activation=={activation}") self.experts = SharedFusedMoE( shared_experts=self.share_expert, gate=self.gate, From b4264eb934447cbf59d1c14a4a9ec58b799dbbbb Mon Sep 17 00:00:00 2001 From: csy0225 Date: Mon, 2 Feb 2026 00:33:15 +0800 Subject: [PATCH 31/34] format: delete print log --- vllm/model_executor/models/step3p5.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 82266869da70..6663c6f02ad7 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -371,8 +371,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.share_expert", ) - if get_tensor_model_parallel_rank() == 0: - print(f"{self.layer_idx}: moe activation=={activation}") self.experts = SharedFusedMoE( shared_experts=self.share_expert, gate=self.gate, From 809d69dc093d41bd042221baf37f62c81638e717 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Mon, 2 Feb 2026 07:04:32 +0800 Subject: [PATCH 32/34] fix: remove --- vllm/model_executor/models/step3p5.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/models/step3p5.py b/vllm/model_executor/models/step3p5.py index 
6663c6f02ad7..f0d7b4a75a9d 100644 --- a/vllm/model_executor/models/step3p5.py +++ b/vllm/model_executor/models/step3p5.py @@ -410,9 +410,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.share_expert is None: assert shared_output is None - if self.share_expert is None: - assert shared_output is None - if self.share_expert is not None: assert shared_output is not None final_hidden_states += shared_output From f86b755f08bfdd0410b0a0e92c71deb60e1d0672 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Mon, 2 Feb 2026 08:19:03 +0800 Subject: [PATCH 33/34] CI: add step3p5 mtp hf examples info --- tests/models/registry.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 398bcbdd9005..408cdea5134d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1116,6 +1116,12 @@ def check_available_online( "Qwen3NextMTP": _HfExamplesInfo( "Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3" ), + "Step3p5MTP": _HfExamplesInfo( + "stepfun-ai/Step-3.5-Flash", + trust_remote_code=True, + speculative_model="stepfun-ai/Step-3.5-Flash", + is_available_online=False + ), } _TRANSFORMERS_BACKEND_MODELS = { From c7b05eba74d4fcabe70398ba34ef44bdddbc3f4c Mon Sep 17 00:00:00 2001 From: csy0225 Date: Mon, 2 Feb 2026 08:19:48 +0800 Subject: [PATCH 34/34] CI: add step3p5 mtp hf examples info --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 408cdea5134d..1bee16c81d01 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1120,7 +1120,7 @@ def check_available_online( "stepfun-ai/Step-3.5-Flash", trust_remote_code=True, speculative_model="stepfun-ai/Step-3.5-Flash", - is_available_online=False + is_available_online=False, ), }