From 41468cf17aa90f82017842aa56c6e13255b437c4 Mon Sep 17 00:00:00 2001 From: Joe Rowell Date: Tue, 28 Apr 2026 12:12:46 +0100 Subject: [PATCH 1/4] Laguna XS.2 implementation Signed-off-by: Joe Rowell Signed-off-by: Joe Rowell --- vllm/model_executor/models/laguna.py | 886 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/reasoning/__init__.py | 4 + .../reasoning/poolside_v1_reasoning_parser.py | 72 ++ vllm/tool_parsers/__init__.py | 4 + vllm/tool_parsers/poolside_v1_tool_parser.py | 576 ++++++++++++ vllm/transformers_utils/config.py | 44 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/laguna.py | 120 +++ 9 files changed, 1693 insertions(+), 16 deletions(-) create mode 100644 vllm/model_executor/models/laguna.py create mode 100644 vllm/reasoning/poolside_v1_reasoning_parser.py create mode 100644 vllm/tool_parsers/poolside_v1_tool_parser.py create mode 100644 vllm/transformers_utils/configs/laguna.py diff --git a/vllm/model_executor/models/laguna.py b/vllm/model_executor/models/laguna.py new file mode 100644 index 000000000000..08f35d691817 --- /dev/null +++ b/vllm/model_executor/models/laguna.py @@ -0,0 +1,886 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Laguna model compatible with HuggingFace weights.""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class LagunaMLP(nn.Module): + """Dense MLP for Laguna (used in mlp_only_layers).""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + # gate_proj and up_proj are kept as separate ColumnParallelLinear + # rather than merged via MergedColumnParallelLinear. The merged form + # requires per-partition NVFP4 global scales (weight_global_scale, + # input_global_scale) to be packed into a length-2 PerTensorScaleParameter + # and then collapsed via .max() in process_weights_after_loading; this + # doesn't round-trip cleanly through Marlin's NVFP4 stacked-layer code + # path. Splitting yields one global scale per Linear, exactly matching + # the standard compressed-tensors per-Linear schema on disk. + self.gate_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + ) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported." + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate, _ = self.gate_proj(x) + up, _ = self.up_proj(x) + x, _ = self.down_proj(F.silu(gate) * up) + return x + + +class LagunaMoE(nn.Module): + """Sparse MoE block for Laguna with optional shared expert and sigmoid routing. + + Key differences from other MoE implementations: + - Uses SIGMOID routing activation (not softmax) + - Shared expert runs in parallel with routed experts (when enabled) + - Matches HF reference: modular_laguna.py LagunaSparseMoeBlock + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.config = config + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + + self.tp_size = get_tensor_model_parallel_world_size() + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + + self.n_routed_experts = config.num_experts + self.n_shared_experts = 1 if config.shared_expert_intermediate_size > 0 else 0 + self.routed_scaling_factor = float( + getattr(config, "moe_routed_scaling_factor", 1.0) + ) + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}." + ) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + eplb_config.num_redundant_experts = ( + eplb_config.num_redundant_experts + if eplb_config.num_redundant_experts is not None + else 0 + ) + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + # Router gate + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + # Shared expert (optional) - passed to FusedMoE for overlap optimization + self.shared_expert: LagunaMLP | None + if config.shared_expert_intermediate_size > 0: + self.shared_expert = LagunaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, # Reduce after shared+routed combine + prefix=f"{prefix}.shared_expert", + ) + else: + self.shared_expert = None + + # Auxiliary-loss-free load-balancing bias (arXiv:2408.15664). The + # checkpoint stores one [num_experts] tensor per MoE layer at + # `mlp.experts.e_score_correction_bias`; registering it as a Parameter + # on the FusedMoE lets the weight loader pick it up and the router + # add it during top-k selection. The fused top-k bias router requires + # float32 regardless of model dtype. + e_score_correction_bias = torch.nn.Parameter( + torch.zeros(config.num_experts, dtype=torch.float32), + requires_grad=False, + ) + + # FusedMoE with SIGMOID routing. Passing `shared_experts=` lets the + # layer overlap the shared-expert compute with the all2all dispatch. + # `apply_routed_scale_to_output=True` makes FusedMoE handle the + # routed_scaling_factor, shared+routed combine, and TP all-reduce + # internally, so forward() just returns the final hidden states. + self.experts = FusedMoE( + shared_experts=self.shared_expert, + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + use_grouped_topk=False, + apply_router_weight_on_input=bool(config.moe_apply_router_weight_on_input), + e_score_correction_bias=e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scale_to_output=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits, _ = self.gate(hidden_states) + router_logits = router_logits.float() + softcap = getattr(self.config, "moe_router_logit_softcapping", 0.0) or 0.0 + if softcap > 0.0: + router_logits = torch.tanh(router_logits / softcap) * softcap + + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class LagunaAttention(nn.Module): + """Laguna attention with optional softplus output gating. + + Supports per-layer sliding window attention when ``config.layer_types`` + is present. Layers whose type is ``"sliding_attention"`` use + ``config.sliding_window``; all other layers (typically labelled + ``"full_attention"``) use full attention. When ``layer_types`` is + absent every layer defaults to full attention for backwards + compatibility. + """ + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position_embeddings: int = 131072, + head_dim: int | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attention_sink: bool = False, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + # Gating flag + self.gating = config.gating + + # Per-layer sliding window (follows Gemma2/Cohere2 convention) + layer_types = getattr(config, "layer_types", None) + if layer_types is not None: + layer_idx = extract_layer_index(prefix) + is_sliding = layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if is_sliding else None + else: + self.sliding_window = None + + # QKV projection (no bias for Laguna) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + # Output projection + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Gating projection (Laguna-specific, optional) + # config.gating may be: + # - True / "per-element": one gate per (head, head_dim) channel + # - "per-head": one gate per head, broadcast across head_dim + if self.gating: + # v5 LagunaConfig uses ``gating=True`` for per-head; older configs + # used ``"per-head"``. Accept both. ``"per-element"`` (or legacy + # ``True``) means per-element gating with output size num_heads × + # head_dim. + gate_per_head = self.gating is True or self.gating == "per-head" + g_out = ( + self.total_num_heads + if gate_per_head + else self.total_num_heads * self.head_dim + ) + self.g_proj = ColumnParallelLinear( + hidden_size, + g_out, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_proj", + ) + self.gate_per_head = gate_per_head + else: + self.g_proj = None + self.gate_per_head = False + + # Attention sinks (learnable per-head bias for SWA layers) + sinks = None + if attention_sink: + self.sink = torch.nn.Parameter( + torch.empty(self.total_num_heads // tp_size, requires_grad=False) + ) + sinks = self.sink + + # Resolve rope params per-layer-type. ``config.rope_parameters`` is + # either a flat dict (legacy) or a nested ``{layer_type: rope_dict}`` + # (v5 Laguna-XS schema). The v5 form is unhashable as-is and would + # crash `get_rope`'s cache lookup, so always pull out the layer's + # sub-dict before forwarding. + layer_type = ( + layer_types[extract_layer_index(prefix)] + if layer_types is not None + else "full_attention" + ) + is_sliding = layer_type == "sliding_attention" + + top_rope = getattr(config, "rope_parameters", None) or {} + if any(isinstance(v, dict) for v in top_rope.values()): + # Nested per-layer-type form. + base_rope = top_rope.get(layer_type) or top_rope.get("full_attention") or {} + else: + base_rope = top_rope + + # Older flat-rope ckpts can carry a separate `swa_rope_parameters` + # for SWA layers. Prefer it when present; otherwise the nested + # rope dict above already supplies the correct sub-config. + swa_rope = getattr(config, "swa_rope_parameters", None) + if ( + is_sliding + and swa_rope is None + and not any(isinstance(v, dict) for v in top_rope.values()) + ): + logger.warning_once( + "Laguna config has sliding_attention layers but neither " + "`swa_rope_parameters` nor a nested per-layer-type " + "`rope_parameters` — SWA layers will reuse the global rope. " + "If the checkpoint was trained with distinct SWA rope " + "(theta / partial_rotary_factor), regenerate its HF config " + "to include either form." + ) + rope_params = swa_rope if (is_sliding and swa_rope is not None) else base_rope + # `partial_rotary_factor` may live on the top-level config (main attention) + # or on the per-layer rope dict itself (e.g. SWA can differ). Inject the + # top-level value into `rope_params` if the dict doesn't already set it. + top_partial = getattr(config, "partial_rotary_factor", None) + if top_partial is not None and "partial_rotary_factor" not in rope_params: + rope_params = {**rope_params, "partial_rotary_factor": top_partial} + + # Rotary embeddings (YaRN) + self.rotary_emb = get_rope( + head_size=self.head_dim, + max_position=max_position_embeddings, + is_neox_style=True, + rope_parameters=rope_params, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn", + sinks=sinks, + ) + + # QK normalization (like Qwen3) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + # Apply gating if enabled (compute softplus in float32 for precision) + if self.gating and self.g_proj is not None: + gate, _ = self.g_proj(hidden_states) + gate = F.softplus(gate.float()).type_as(attn_output) + if self.gate_per_head: + # gate: [..., num_heads]; broadcast across head_dim + attn_shape = attn_output.shape + attn_output = ( + attn_output.view(*attn_shape[:-1], self.num_heads, self.head_dim) + * gate.unsqueeze(-1) + ).view(attn_shape) + else: + attn_output = attn_output * gate + + output, _ = self.o_proj(attn_output) + return output + + +class LagunaDecoderLayer(nn.Module): + def __init__( + self, + config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) + + # Determine if this layer uses sliding window attention + layer_types = getattr(config, "layer_types", None) + is_sliding = ( + layer_types is not None and layer_types[layer_idx] == "sliding_attention" + ) + + # Enable attention sinks on SWA layers when configured + attention_sink = is_sliding and getattr( + config, "swa_attention_sink_enabled", False + ) + + # Optional per-layer override of head count (Laguna-XS). + per_layer_heads = getattr(config, "num_attention_heads_per_layer", None) + layer_num_heads = ( + per_layer_heads[layer_idx] + if per_layer_heads is not None + else config.num_attention_heads + ) + + self.self_attn = LagunaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=layer_num_heads, + num_kv_heads=config.num_key_value_heads, + max_position_embeddings=config.max_position_embeddings, + head_dim=getattr(config, "head_dim", None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_sink=attention_sink, + ) + + # Check if this layer uses MoE or dense MLP (matches Qwen2/Qwen3 convention) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + self.is_moe_layer = ( + (layer_idx not in mlp_only_layers) + and (config.num_experts > 0) + and ((layer_idx + 1) % config.decoder_sparse_step == 0) + ) + + if self.is_moe_layer: + self.mlp = LagunaMoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = LagunaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class LagunaModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb + eplb_config = vllm_config.parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts + self.config = config + self.quant_config = quant_config + + # Disable the model-level sliding-window fallback in Attention.__init__. + # Laguna drives SWA per-layer via `layer_types`, passing + # `per_layer_sliding_window=self.sliding_window` (None for global + # layers). Without this, global layers whose `per_layer_sliding_window` + # is None would pick up `cache_config.sliding_window` + # (populated from `config.sliding_window`) as a fallback, silently + # applying a 512-token window to full-attention layers. + if cache_config is not None: + cache_config.sliding_window = None + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: LagunaDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """Get expert parameter mapping for weight loading. + + Returns mapping tuples of (param_name, weight_name, expert_id, shard_id) + that handle both weights and quantization scales. + """ + return FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + # gate_proj and up_proj are loaded as separate Linears (see + # LagunaMLP) so no merge entry is needed here. + ] + + # Suffixes to skip for GPTQ/modelopt models if param doesn't exist + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + + tp_rank = get_tensor_model_parallel_rank() + + for name, loaded_weight in weights: + # Handle attention sinks (distributed across ranks). Derive the + # per-rank slice from the parameter's own shape so per-layer + # variations in head count are handled correctly. + if "sink" in name: + param = params_dict.get(name) + if param is not None: + layer_heads_per_rank = param.shape[0] + layer_head_start = tp_rank * layer_heads_per_rank + narrow_weight = loaded_weight.narrow( + 0, layer_head_start, layer_heads_per_rank + ) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + + # Handle KV cache quantization scales + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + assert loaded_weight.numel() == 1, ( + f"KV scale numel {loaded_weight.numel()} != 1" + ) + loaded_weight = loaded_weight.squeeze() + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + # Handle stacked params (QKV, gate_up for + # non-expert layers and shared_expert) + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + # Skip expert weights - handled below via expert_params_mapping + if "mlp.experts" in name and "shared_expert" not in name: + continue + name = name.replace(weight_name, param_name) + + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + # Remap FP8 kv_scale names for backwards compatibility + if name.endswith("scale"): + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + # Try expert params mapping (handles weights + quantization scales) + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Mark as expert weight so we skip regular loading below + is_expert_weight = True + + # Create mapped name without modifying original + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + if name_mapped not in params_dict: + continue + + param = params_dict[name_mapped] + # Use return_success to handle expert parallelism correctly + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + loaded_params.add(name_mapped) + break + else: + # Expert weight not mapped to this rank - skip + if is_expert_weight: + continue + + # Remap kv_scale names before the ignore_suffixes filter: + # the suffix list includes .k_scale/.v_scale, so filtering + # first drops the checkpoint key before remap can rewrite + # it to the .attn.* name that exists in params_dict. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class LagunaForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.model = LagunaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 80cc8b895345..eba288dcc77a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -150,6 +150,7 @@ "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"), + "LagunaForCausalLM": ("laguna", "LagunaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # For decapoda-research/llama-* diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 755fa56d294c..2347eae54c25 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -32,6 +32,10 @@ "deepseek_v3_reasoning_parser", "DeepSeekV3ReasoningParser", ), + "poolside_v1": ( + "poolside_v1_reasoning_parser", + "PoolsideV1ReasoningParser", + ), "ernie45": ( "ernie45_reasoning_parser", "Ernie45ReasoningParser", diff --git a/vllm/reasoning/poolside_v1_reasoning_parser.py b/vllm/reasoning/poolside_v1_reasoning_parser.py new file mode 100644 index 000000000000..30031d8513a9 --- /dev/null +++ b/vllm/reasoning/poolside_v1_reasoning_parser.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Laguna reasoning parser. + +``DeepSeekV3ReasoningParser.is_reasoning_end`` walks the entire +token sequence backwards and returns ``True`` on the first ```` it +sees. When called on ``prompt_token_ids`` that mistakes any stray +```` in conversation history, few-shot examples or tool descriptions +for a template-injected "thinking already ended" marker. In the streaming +path (see ``vllm/entrypoints/openai/chat_completion/serving.py``, +``prompt_is_reasoning_end_arr``) that false positive short-circuits the +reasoning parser for the whole response, so any ``...`` the +model emits itself ends up in the content field instead of the reasoning +field. + +As we have more flexible templates, we instead scope +the backward search to the current assistant turn: the +walk terminates as soon as we hit the ```` start-of-message +token. A ```` in a prior user turn or few-shot example is no longer +visible. +""" + +from collections.abc import Sequence + +from transformers import PreTrainedTokenizerBase + +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class PoolsideV1ReasoningParser(DeepSeekV3ReasoningParser): + """Drop-in replacement for ``deepseek_v3`` that tolerates ```` + tokens appearing anywhere in the prompt other than the generation prefix. + """ + + _start_of_assistant_message = "" + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + if self._start_of_assistant_message not in self.vocab: + raise ValueError( + f"Tokenizer must contain {self._start_of_assistant_message!r} token" + ) + self._start_of_assistant_message_token_id = self.vocab[ + self._start_of_assistant_message + ] + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + # IdentityReasoningParser always returns True: no reasoning to parse. + if isinstance(self._parser, IdentityReasoningParser): + return True + + assert isinstance(self._parser, DeepSeekR1ReasoningParser) + for tok_id in reversed(input_ids): + # : reasoning is not yet ended. + if tok_id == self._parser.start_token_id: + return False + # : reasoning has ended. + if tok_id == self._parser.end_token_id: + return True + # : reached the start of the current assistant turn + # without seeing either marker. Anything further back belongs to + # the prior conversation and should be ignored. + if tok_id == self._start_of_assistant_message_token_id: + return False + return False + + +__all__ = ["PoolsideV1ReasoningParser"] diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py index 8a39ca825d5f..61a11cbcc376 100644 --- a/vllm/tool_parsers/__init__.py +++ b/vllm/tool_parsers/__init__.py @@ -66,6 +66,10 @@ "hermes_tool_parser", "Hermes2ProToolParser", ), + "poolside_v1": ( + "poolside_v1_tool_parser", + "PoolsideV1ToolParser", + ), "hunyuan_a13b": ( "hunyuan_a13b_tool_parser", "HunyuanA13BToolParser", diff --git a/vllm/tool_parsers/poolside_v1_tool_parser.py b/vllm/tool_parsers/poolside_v1_tool_parser.py new file mode 100644 index 000000000000..9bf74939343b --- /dev/null +++ b/vllm/tool_parsers/poolside_v1_tool_parser.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GLM-4 Tool Call Parser with incremental string streaming support. + +This parser fixes the streaming issue reported in Issue #32829 where long string +parameters (e.g., file content with 4000+ characters of code) are buffered until +complete, causing multi-second delays before the user sees any content. + +The fix streams string values incrementally as they arrive, providing a true +streaming experience for long content. +""" + +import ast +import json +from collections.abc import Sequence +from typing import Any + +import partial_json_parser.core.complete +import regex as re +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + Tool, + ToolParser, +) + +logger = init_logger(__name__) + + +class PoolsideV1ToolParser(ToolParser): + """Tool parser for GLM-4 models with incremental string streaming. + + This parser emits tool-call deltas incrementally as arguments arrive. + For string-type parameters, content is streamed character-by-character + rather than waiting for the complete tag. + """ + + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): + super().__init__(tokenizer, tools) + # Stateful streaming fields + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict[str, Any]] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [] + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.arg_key_start: str = "" + self.arg_key_end: str = "" + self.arg_val_start: str = "" + self.arg_val_end: str = "" + + self.tool_calls_start_token = self.tool_call_start_token + + self.func_call_regex = re.compile(r".*?", re.DOTALL) + self.func_detail_regex = re.compile( + r"([^\n]*)\n(.*)", re.DOTALL + ) + self.func_arg_regex = re.compile( + r"(.*?)\s*(.*?)", re.DOTALL + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." + ) + + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self._buffer: str = "" + + # Streaming state for incremental tool-call streaming + self._in_tool_call: bool = False + self._current_tool_name: str | None = None + self._pending_key: str | None = None + self._streaming_string_value: bool = False + self._tool_call_ids: list[str] = [] + self._args_started: list[bool] = [] + self._args_closed: list[bool] = [] + self._seen_keys: list[set[str]] = [] + + @staticmethod + def _deserialize(value: str) -> Any: + try: + return json.loads(value) + except json.JSONDecodeError: + pass + + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError): + pass + + return value + + @staticmethod + def _json_escape_string_content(s: str) -> str: + """JSON-escape string content for incremental streaming. + + This escapes the content that goes INSIDE a JSON string (between quotes), + not including the surrounding quotes themselves. + """ + if not s: + return "" + return json.dumps(s, ensure_ascii=False)[1:-1] + + @staticmethod + def _is_string_type( + tool_name: str, + arg_name: str, + tools: list[Tool] | None, + ) -> bool: + if tools is None: + return False + for tool in tools: + if tool.function.name != tool_name: + continue + if tool.function.parameters is None: + return False + arg_type = ( + tool.function.parameters.get("properties", {}) + .get(arg_name, {}) + .get("type", None) + ) + return arg_type == "string" + logger.debug("No tool named '%s'.", tool_name) + return False + + @staticmethod + def _tools_enabled(request: ChatCompletionRequest) -> bool: + """Return whether tool parsing should be applied for this request.""" + try: + tools = getattr(request, "tools", None) + tool_choice = getattr(request, "tool_choice", None) + return bool(tools) and tool_choice != "none" + except Exception: + logger.exception("Failed to determine if tools are enabled.") + return False + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """Adjust request parameters for tool call token handling.""" + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + # Ensure tool call tokens (, ) are not skipped + # during decoding. Even though they are not marked as special tokens, + # setting skip_special_tokens=False ensures proper handling in + # transformers 5.x where decoding behavior may have changed. + request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + matched_tool_calls = self.func_call_regex.findall(model_output) + logger.debug("model_output: %s", model_output) + try: + tool_calls: list[ToolCall] = [] + for match in matched_tool_calls: + tc_detail = self.func_detail_regex.search(match) + if not tc_detail: + logger.warning( + "Failed to parse tool call details from: %s", + match, + ) + continue + tc_name = tc_detail.group(1).strip() + tc_args = tc_detail.group(2) + pairs = self.func_arg_regex.findall(tc_args) if tc_args else [] + arg_dct: dict[str, Any] = {} + for key, value in pairs: + arg_key = key.strip() + arg_val = value.strip() + if not self._is_string_type(tc_name, arg_key, request.tools): + arg_val = self._deserialize(arg_val) + logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val) + arg_dct[arg_key] = arg_val + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=tc_name, + arguments=json.dumps(arg_dct, ensure_ascii=False), + ), + ) + ) + except Exception: + logger.exception("Failed to extract tool call spec") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + else: + if len(tool_calls) > 0: + content: str | None = model_output[ + : model_output.find(self.tool_calls_start_token) + ] + # Normalize empty/whitespace-only content to None + if not content or not content.strip(): + content = None + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + if not self._tools_enabled(request): + return DeltaMessage(content=delta_text) if delta_text else None + + self._buffer += delta_text + + pending_deltas: dict[int, DeltaToolCall] = {} + content: str | None = None + + while True: + if not self._in_tool_call: + start_idx = self._buffer.find(self.tool_call_start_token) + if start_idx == -1: + # Check for partial start token at end of buffer + for i in range(1, len(self.tool_call_start_token)): + if self._buffer.endswith(self.tool_call_start_token[:i]): + out = self._buffer[:-i] + self._buffer = self._buffer[-i:] + if out: + content = (content or "") + out + break + else: + out = self._buffer + self._buffer = "" + if out: + content = (content or "") + out + break + + if start_idx > 0: + content = (content or "") + self._buffer[:start_idx] + self._buffer = self._buffer[start_idx:] + + self._buffer = self._buffer[len(self.tool_call_start_token) :] + self._begin_tool_call() + continue + + # Parse tool name first + if not self.current_tool_name_sent: + nl = self._buffer.find("\n") + ak = self._buffer.find(self.arg_key_start) + end = self._buffer.find(self.tool_call_end_token) + candidates = [i for i in [nl, ak, end] if i != -1] + if not candidates: + break + cut = min(candidates) + tool_name = self._buffer[:cut].strip() + if tool_name == "" and cut == end: + # Handle empty tool call like ``. + # Consume the tokens and reset state to avoid infinite loop. + self._buffer = self._buffer[end + len(self.tool_call_end_token) :] + self._finish_tool_call() + self._revert_last_tool_call_state() + continue + + if cut == nl: + self._buffer = self._buffer[nl + 1 :] + else: + self._buffer = self._buffer[cut:] + + self._current_tool_name = tool_name + self.current_tool_name_sent = True + self._update_tool_name(pending_deltas, tool_name) + continue + + assert self._current_tool_name is not None + + # Handle incremental string value streaming + if self._streaming_string_value: + val_end = self._buffer.find(self.arg_val_end) + if val_end != -1: + raw_content = self._buffer[:val_end] + self._buffer = self._buffer[val_end + len(self.arg_val_end) :] + self._streaming_string_value = False + self._pending_key = None + + escaped = self._json_escape_string_content(raw_content) + frag = escaped + '"' + self.streamed_args_for_tool[self.current_tool_id] += frag + self._update_tool_args(pending_deltas, frag) + continue + + # Check for partial at end + safe_len = len(self._buffer) + for i in range(1, len(self.arg_val_end)): + if self._buffer.endswith(self.arg_val_end[:i]): + safe_len = len(self._buffer) - i + break + + if safe_len > 0: + to_emit = self._buffer[:safe_len] + self._buffer = self._buffer[safe_len:] + escaped = self._json_escape_string_content(to_emit) + if escaped: + self.streamed_args_for_tool[self.current_tool_id] += escaped + self._update_tool_args(pending_deltas, escaped) + break + + # If we have a pending key, parse its value + if self._pending_key is not None: + val_pos = self._buffer.find(self.arg_val_start) + if val_pos == -1: + break + if val_pos > 0: + self._buffer = self._buffer[val_pos:] + + key = (self._pending_key or "").strip() + + is_string = self._is_string_type( + self._current_tool_name, key, request.tools + ) + + if is_string: + # String type: stream incrementally + self._buffer = self._buffer[len(self.arg_val_start) :] + + if key in self._seen_keys[self.current_tool_id]: + self._pending_key = None + continue + + self._seen_keys[self.current_tool_id].add(key) + key_json = json.dumps(key, ensure_ascii=False) + + if not self._args_started[self.current_tool_id]: + frag = "{" + key_json + ': "' + self._args_started[self.current_tool_id] = True + else: + frag = ", " + key_json + ': "' + + self.streamed_args_for_tool[self.current_tool_id] += frag + self._streaming_string_value = True + self._update_tool_args(pending_deltas, frag) + continue + + # Non-string type: wait for complete value + val_end = self._buffer.find(self.arg_val_end) + if val_end == -1: + break + + raw_val = self._buffer[len(self.arg_val_start) : val_end].strip() + self._buffer = self._buffer[val_end + len(self.arg_val_end) :] + self._pending_key = None + + frag_or_none = self._append_arg_fragment(key=key, raw_val=raw_val) + if frag_or_none: + self._update_tool_args(pending_deltas, frag_or_none) + continue + + # Parse next arg or close + end_pos = self._buffer.find(self.tool_call_end_token) + key_pos = self._buffer.find(self.arg_key_start) + if end_pos != -1 and (key_pos == -1 or end_pos < key_pos): + self._buffer = self._buffer[end_pos + len(self.tool_call_end_token) :] + frag_or_none = self._close_args_if_needed() + # Finalize prev_tool_call_arr with complete parsed arguments + if self._current_tool_name: + try: + full_args_str = self.streamed_args_for_tool[ + self.current_tool_id + ] + args_dict = json.loads(full_args_str) + self.prev_tool_call_arr[self.current_tool_id] = { + "name": self._current_tool_name, + "arguments": args_dict, + } + except (json.JSONDecodeError, IndexError) as e: + logger.warning( + "Failed to finalize tool call state for tool %d: %s", + self.current_tool_id, + e, + ) + self._finish_tool_call() + if frag_or_none: + self._update_tool_args(pending_deltas, frag_or_none) + continue + + if key_pos == -1: + break + if key_pos > 0: + self._buffer = self._buffer[key_pos:] + key_end = self._buffer.find(self.arg_key_end) + if key_end == -1: + break + key = self._buffer[len(self.arg_key_start) : key_end] + self._buffer = self._buffer[key_end + len(self.arg_key_end) :] + self._pending_key = key + continue + + tool_calls = list(pending_deltas.values()) + if content is None and len(tool_calls) == 0: + if request.logprobs: + return DeltaMessage(content="") + return None + return DeltaMessage(content=content, tool_calls=tool_calls) + + def _ensure_tool_state(self) -> None: + while len(self._tool_call_ids) <= self.current_tool_id: + self._tool_call_ids.append( + make_tool_call_id(id_type="random", func_name=None, idx=None) + ) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self._args_started) <= self.current_tool_id: + self._args_started.append(False) + while len(self._args_closed) <= self.current_tool_id: + self._args_closed.append(False) + while len(self._seen_keys) <= self.current_tool_id: + self._seen_keys.append(set()) + + def _begin_tool_call(self) -> None: + if self.current_tool_id == -1: + self.current_tool_id = 0 + else: + self.current_tool_id += 1 + self._ensure_tool_state() + self.current_tool_name_sent = False + self._current_tool_name = None + self._pending_key = None + self._streaming_string_value = False + self._in_tool_call = True + + def _finish_tool_call(self) -> None: + self._in_tool_call = False + self._current_tool_name = None + self._pending_key = None + self._streaming_string_value = False + + def _revert_last_tool_call_state(self) -> None: + """Revert the state allocation for the last tool call.""" + if self.current_tool_id < 0: + return + self._tool_call_ids.pop() + self.streamed_args_for_tool.pop() + self.prev_tool_call_arr.pop() + self._args_started.pop() + self._args_closed.pop() + self._seen_keys.pop() + self.current_tool_id -= 1 + + def _get_or_create_delta(self, pending: dict[int, DeltaToolCall]) -> DeltaToolCall: + idx = self.current_tool_id + if idx not in pending: + pending[idx] = DeltaToolCall( + index=idx, + function=DeltaFunctionCall(), + ) + delta = pending[idx] + assert delta.function is not None + return delta + + def _update_tool_name( + self, pending: dict[int, DeltaToolCall], tool_name: str + ) -> None: + self.prev_tool_call_arr[self.current_tool_id] = { + "name": self._current_tool_name, + "arguments": {}, + } + delta = self._get_or_create_delta(pending) + delta.id = self._tool_call_ids[self.current_tool_id] + delta.type = "function" + delta.function.name = tool_name + if delta.function.arguments is None: + delta.function.arguments = "" + + @staticmethod + def _complete_json_prefix( + json_prefix: str, + allowed_partial_types: Allow, + ) -> dict | None: + """Complete a partial JSON prefix into a valid JSON object. + + Returns (formatted_prefix, parsed_dict) or None on failure. + + Note: ``partial_json_parser`` strips trailing whitespace before + parsing (``complete.py:20``), which means the returned slice is + shorter than ``json_prefix`` when it has trailing whitespace. + Since the parser controls the construction of the json_prefix value, + this code relies on it being a valid prefix and we only use the fix for + the completion of the JSON object. + """ + try: + _, partial_str_completion = partial_json_parser.core.complete.fix( + json_prefix, + allowed_partial_types, + ) + return json.loads(json_prefix + partial_str_completion) + except Exception: + return None + + def _update_tool_args( + self, pending: dict[int, DeltaToolCall], fragment: str + ) -> None: + result = self._complete_json_prefix( + self.streamed_args_for_tool[self.current_tool_id], + Allow.ALL, + ) + if result is not None: + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = result + delta = self._get_or_create_delta(pending) + if delta.function.arguments is None: + delta.function.arguments = "" + delta.function.arguments += fragment + + def _append_arg_fragment( + self, + *, + key: str, + raw_val: str, + ) -> str | None: + key = key.strip() + if not key: + return None + if key in self._seen_keys[self.current_tool_id]: + return None + + # This function is only called for non-string types (already checked + # by _is_string_type in the caller), so we always deserialize. + val_obj: Any = self._deserialize(raw_val) + + key_json = json.dumps(key, ensure_ascii=False) + val_json = json.dumps(val_obj, ensure_ascii=False) + + if not self._args_started[self.current_tool_id]: + fragment = "{" + key_json + ": " + val_json + self._args_started[self.current_tool_id] = True + else: + fragment = ", " + key_json + ": " + val_json + + self._seen_keys[self.current_tool_id].add(key) + self.streamed_args_for_tool[self.current_tool_id] += fragment + return fragment + + def _close_args_if_needed(self) -> str | None: + if self._args_closed[self.current_tool_id]: + return None + self._args_closed[self.current_tool_id] = True + if not self._args_started[self.current_tool_id]: + fragment = "{}" + self.streamed_args_for_tool[self.current_tool_id] = fragment + else: + fragment = "}" + self.streamed_args_for_tool[self.current_tool_id] += fragment + return fragment diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index bb6ad1056b7b..47b74093b06c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -127,6 +127,7 @@ def __getitem__(self, key): qwen3_next="Qwen3NextConfig", qwen3_5="Qwen3_5Config", qwen3_5_moe="Qwen3_5MoeConfig", + laguna="LagunaConfig", lfm2_moe="Lfm2MoeConfig", tarsier2="Tarsier2Config", ) @@ -409,22 +410,33 @@ def patch_rope_parameters(config: PretrainedConfig) -> None: ompe = getattr(config, "original_max_position_embeddings", None) if Version(version("transformers")) < Version("5.0.0"): - # Transformers v4 installed, legacy config fields may be present - if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: - config.rope_parameters = rope_scaling - if ( - rope_theta is not None - or partial_rotary_factor is not None - or ompe is not None - ) and not getattr(config, "rope_parameters", None): - config.rope_parameters = {"rope_type": "default"} - # Patch legacy fields into rope_parameters - if rope_theta is not None: - config.rope_parameters["rope_theta"] = rope_theta - if partial_rotary_factor is not None: - config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor - if ompe is not None: - config.rope_parameters["original_max_position_embeddings"] = ompe + # Transformers v4 installed, legacy config fields may be present. + existing_rp = getattr(config, "rope_parameters", None) + if isinstance(existing_rp, dict) and is_rope_parameters_nested(existing_rp): + # Interleaved-attention models (e.g. Laguna-XS.2) ship a nested + # {layer_type: {...}} rope_parameters that the model code indexes + # by layer_type. The per-layer-type sub-dicts already carry the + # correct rope_theta / partial_rotary_factor / ompe (the converter + # places top-level legacy fields inside full_attention), so don't + # merge top-level fields here — that would shadow the per-type + # values and break sliding-attention layers. + pass + else: + if (rope_scaling := getattr(config, "rope_scaling", None)) is not None: + config.rope_parameters = rope_scaling + if ( + rope_theta is not None + or partial_rotary_factor is not None + or ompe is not None + ) and not getattr(config, "rope_parameters", None): + config.rope_parameters = {"rope_type": "default"} + # Patch legacy fields into rope_parameters + if rope_theta is not None: + config.rope_parameters["rope_theta"] = rope_theta + if partial_rotary_factor is not None: + config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor + if ompe is not None: + config.rope_parameters["original_max_position_embeddings"] = ompe elif rope_theta is not None or getattr(config, "rope_parameters", None): # Transformers v5 installed # Patch these fields in case they used non-standard names diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 667ed5a2596c..8c4d01a428bd 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -45,6 +45,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. "RWConfig": "vllm.transformers_utils.configs.falcon", "JAISConfig": "vllm.transformers_utils.configs.jais", + "LagunaConfig": "vllm.transformers_utils.configs.laguna", "Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe", "MedusaConfig": "vllm.transformers_utils.configs.medusa", "MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm", @@ -105,6 +106,7 @@ "IsaacConfig", "RWConfig", "JAISConfig", + "LagunaConfig", "Lfm2MoeConfig", "MedusaConfig", "MiDashengLMConfig", diff --git a/vllm/transformers_utils/configs/laguna.py b/vllm/transformers_utils/configs/laguna.py new file mode 100644 index 000000000000..2702d3af5aa1 --- /dev/null +++ b/vllm/transformers_utils/configs/laguna.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class LagunaConfig(PretrainedConfig): + model_type = "laguna" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.g_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: int = 100352, + hidden_size: int = 2048, + intermediate_size: int = 8192, + num_hidden_layers: int = 40, + num_attention_heads: int = 48, + num_key_value_heads: int = 8, + head_dim: int = 128, + qkv_bias: bool = False, + attention_bias: bool = False, + gating: bool | str = True, + hidden_act: str = "silu", + max_position_embeddings: int = 131072, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-6, + use_cache: bool = True, + tie_word_embeddings: bool = False, + rope_theta: float = 500000.0, + rope_scaling: dict | None = None, + rope_parameters: dict | None = None, + partial_rotary_factor: float = 1.0, + attention_dropout: float = 0.0, + sliding_window: int | None = None, + layer_types: list[str] | None = None, + swa_attention_sink_enabled: bool = False, + swa_rope_parameters: dict | None = None, + num_attention_heads_per_layer: list[int] | None = None, + num_experts: int = 256, + num_experts_per_tok: int = 8, + moe_intermediate_size: int = 512, + shared_expert_intermediate_size: int = 512, + norm_topk_prob: bool = True, + decoder_sparse_step: int = 1, + mlp_only_layers: list[int] | None = None, + router_aux_loss_coef: float = 0.001, + output_router_logits: bool = False, + moe_routed_scaling_factor: float = 1.0, + moe_apply_router_weight_on_input: bool = False, + **kwargs, + ): + if mlp_only_layers is None: + mlp_only_layers = [0] + + # Accept either v4-style (rope_theta + rope_scaling) or v5-style + # (rope_parameters). Translate v5 → v4 so downstream code has one path. + if rope_parameters is not None: + rp = dict(rope_parameters) + rope_theta = float(rp.pop("rope_theta", rope_theta)) + rt = rp.pop("rope_type", None) + if rt is not None and rt != "default": + rope_scaling = {"rope_type": rt, **rp} + elif rp and rope_scaling is None: + rope_scaling = {"rope_type": "default", **rp} + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.qkv_bias = qkv_bias + self.attention_bias = attention_bias + self.gating = gating + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self.attention_dropout = attention_dropout + self.sliding_window = sliding_window + self.layer_types = layer_types + self.swa_attention_sink_enabled = swa_attention_sink_enabled + self.swa_rope_parameters = swa_rope_parameters + self.num_attention_heads_per_layer = num_attention_heads_per_layer + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.norm_topk_prob = norm_topk_prob + self.decoder_sparse_step = decoder_sparse_step + self.mlp_only_layers = mlp_only_layers + self.router_aux_loss_coef = router_aux_loss_coef + self.output_router_logits = output_router_logits + self.moe_routed_scaling_factor = moe_routed_scaling_factor + self.moe_apply_router_weight_on_input = moe_apply_router_weight_on_input + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["LagunaConfig"] From 92e83148e70d0841c1610a8ddaaa8d87f5c15616 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 28 Apr 2026 10:26:52 -0400 Subject: [PATCH 2/4] fix pre-commit Signed-off-by: Robert Shaw --- vllm/tool_parsers/poolside_v1_tool_parser.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/tool_parsers/poolside_v1_tool_parser.py b/vllm/tool_parsers/poolside_v1_tool_parser.py index 9bf74939343b..bbba4f25d205 100644 --- a/vllm/tool_parsers/poolside_v1_tool_parser.py +++ b/vllm/tool_parsers/poolside_v1_tool_parser.py @@ -491,6 +491,7 @@ def _update_tool_name( delta = self._get_or_create_delta(pending) delta.id = self._tool_call_ids[self.current_tool_id] delta.type = "function" + assert delta.function is not None delta.function.name = tool_name if delta.function.arguments is None: delta.function.arguments = "" @@ -530,6 +531,7 @@ def _update_tool_args( if result is not None: self.prev_tool_call_arr[self.current_tool_id]["arguments"] = result delta = self._get_or_create_delta(pending) + assert delta.function is not None if delta.function.arguments is None: delta.function.arguments = "" delta.function.arguments += fragment From 304e3fe8b9f05f036fe51cfa7547db1c1e4866a0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Tue, 28 Apr 2026 10:32:34 -0400 Subject: [PATCH 3/4] fix pre-commit again Signed-off-by: Robert Shaw --- vllm/tool_parsers/poolside_v1_tool_parser.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/tool_parsers/poolside_v1_tool_parser.py b/vllm/tool_parsers/poolside_v1_tool_parser.py index bbba4f25d205..f14b47362917 100644 --- a/vllm/tool_parsers/poolside_v1_tool_parser.py +++ b/vllm/tool_parsers/poolside_v1_tool_parser.py @@ -32,6 +32,9 @@ FunctionCall, ToolCall, ) +from vllm.entrypoints.openai.responses.protocol import ( + ResponsesRequest, +) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import ( @@ -153,7 +156,9 @@ def _tools_enabled(request: ChatCompletionRequest) -> bool: logger.exception("Failed to determine if tools are enabled.") return False - def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + def adjust_request( + self, request: ChatCompletionRequest | ResponsesRequest + ) -> ChatCompletionRequest | ResponsesRequest: """Adjust request parameters for tool call token handling.""" request = super().adjust_request(request) if request.tools and request.tool_choice != "none": From ce47c46bd2bc01a53737b0abdd23d0ee20443f90 Mon Sep 17 00:00:00 2001 From: Joe Rowell Date: Tue, 28 Apr 2026 16:52:53 +0100 Subject: [PATCH 4/4] Add model to registry Signed-off-by: Joe Rowell --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6b041c67071d..19304d803160 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -372,6 +372,7 @@ def check_available_online( "KimiLinearForCausalLM": _HfExamplesInfo( "moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True ), + "LagunaForCausalLM": _HfExamplesInfo("poolside/Laguna-XS.2"), "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"), "Lfm2MoeForCausalLM": _HfExamplesInfo( "LiquidAI/LFM2-8B-A1B",