diff --git a/tests/models/registry.py b/tests/models/registry.py
index 6b041c67071d..19304d803160 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -372,6 +372,7 @@ def check_available_online(
"KimiLinearForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-Linear-48B-A3B-Instruct", trust_remote_code=True
),
+ "LagunaForCausalLM": _HfExamplesInfo("poolside/Laguna-XS.2"),
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"),
"Lfm2MoeForCausalLM": _HfExamplesInfo(
"LiquidAI/LFM2-8B-A1B",
diff --git a/vllm/model_executor/models/laguna.py b/vllm/model_executor/models/laguna.py
new file mode 100644
index 000000000000..08f35d691817
--- /dev/null
+++ b/vllm/model_executor/models/laguna.py
@@ -0,0 +1,886 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Laguna model compatible with HuggingFace weights."""
+
+import typing
+from collections.abc import Callable, Iterable
+from itertools import islice
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+ get_ep_group,
+ get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+from vllm.sequence import IntermediateTensors
+
+logger = init_logger(__name__)
+
+
+class LagunaMLP(nn.Module):
+    """Dense gated MLP for Laguna (used in mlp_only_layers).
+
+    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``. ``silu`` is the
+    only supported activation.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: QuantizationConfig | None = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        """Create the gate, up and down projections.
+
+        Args:
+            hidden_size: Input size of gate/up and output size of down.
+            intermediate_size: Output size of gate/up and input size of down.
+            hidden_act: Activation name; must be ``"silu"``.
+            quant_config: Quantization config forwarded to every projection.
+            reduce_results: Forwarded to ``down_proj``.
+            prefix: Dotted module path used to derive sub-module prefixes.
+
+        Raises:
+            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
+        """
+        super().__init__()
+        # Options shared by all three projections.
+        common = {"bias": False, "quant_config": quant_config}
+
+        # gate_proj and up_proj stay two separate ColumnParallelLinear layers
+        # instead of one MergedColumnParallelLinear. The merged form needs
+        # the per-partition NVFP4 global scales (weight_global_scale,
+        # input_global_scale) packed into a length-2 PerTensorScaleParameter
+        # and collapsed via .max() in process_weights_after_loading, which
+        # doesn't round-trip cleanly through Marlin's NVFP4 stacked-layer
+        # code path. Keeping them split gives one global scale per Linear,
+        # matching the standard compressed-tensors per-Linear schema on disk.
+        self.gate_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            prefix=f"{prefix}.gate_proj",
+            **common,
+        )
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            prefix=f"{prefix}.up_proj",
+            **common,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            reduce_results=reduce_results,
+            prefix=f"{prefix}.down_proj",
+            **common,
+        )
+
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. Only silu is supported."
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply the SiLU-gated feed-forward transform to ``x``."""
+        gate_out, _ = self.gate_proj(x)
+        up_out, _ = self.up_proj(x)
+        gated = F.silu(gate_out) * up_out
+        projected, _ = self.down_proj(gated)
+        return projected
+
+
+class LagunaMoE(nn.Module):
+    """Sparse MoE block for Laguna with optional shared expert and sigmoid routing.
+
+    Key differences from other MoE implementations:
+    - Uses SIGMOID routing activation (not softmax)
+    - Shared expert runs in parallel with routed experts (when enabled)
+    - Matches HF reference: modular_laguna.py LagunaSparseMoeBlock
+    """
+
+    def __init__(
+        self,
+        config,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        enable_eplb: bool = False,
+    ):
+        """Build the router, the optional shared expert and the fused experts.
+
+        Args:
+            config: HF model config; the MoE-related fields are read from it.
+            quant_config: Quantization config for the experts and the shared
+                expert (the router gate is always built unquantized).
+            prefix: Dotted module path used to derive sub-module prefixes.
+            enable_eplb: Forwarded to ``FusedMoE``.
+
+        Raises:
+            ValueError: If the tensor-parallel size exceeds ``num_experts``.
+        """
+        super().__init__()
+        self.config = config
+        self.num_experts = config.num_experts
+        self.top_k = config.num_experts_per_tok
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+
+        has_shared_expert = config.shared_expert_intermediate_size > 0
+        self.n_routed_experts = config.num_experts
+        self.n_shared_experts = 1 if has_shared_expert else 0
+        self.routed_scaling_factor = float(
+            getattr(config, "moe_routed_scaling_factor", 1.0)
+        )
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}."
+            )
+
+        # Load balancing settings. A missing redundant-expert count is
+        # written back to the shared EPLB config as 0 before it is read.
+        eplb_config = get_current_vllm_config().parallel_config.eplb_config
+        self.enable_eplb = enable_eplb
+        if eplb_config.num_redundant_experts is None:
+            eplb_config.num_redundant_experts = 0
+        else:
+            eplb_config.num_redundant_experts = eplb_config.num_redundant_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = (
+            self.physical_expert_start + self.n_local_physical_experts
+        )
+
+        # Router gate
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        # Shared expert (optional) - passed to FusedMoE for overlap optimization
+        self.shared_expert: LagunaMLP | None = (
+            LagunaMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,  # Reduce after shared+routed combine
+                prefix=f"{prefix}.shared_expert",
+            )
+            if has_shared_expert
+            else None
+        )
+
+        # Auxiliary-loss-free load-balancing bias (arXiv:2408.15664). The
+        # checkpoint stores one [num_experts] tensor per MoE layer at
+        # `mlp.experts.e_score_correction_bias`; registering it as a Parameter
+        # on the FusedMoE lets the weight loader pick it up and the router
+        # add it during top-k selection. The fused top-k bias router requires
+        # float32 regardless of model dtype.
+        correction_bias = torch.nn.Parameter(
+            torch.zeros(config.num_experts, dtype=torch.float32),
+            requires_grad=False,
+        )
+
+        # FusedMoE with SIGMOID routing. Passing `shared_experts=` lets the
+        # layer overlap the shared-expert compute with the all2all dispatch.
+        # `apply_routed_scale_to_output=True` makes FusedMoE handle the
+        # routed_scaling_factor, shared+routed combine, and TP all-reduce
+        # internally, so forward() just returns the final hidden states.
+        self.experts = FusedMoE(
+            shared_experts=self.shared_expert,
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func="sigmoid",
+            use_grouped_topk=False,
+            apply_router_weight_on_input=bool(config.moe_apply_router_weight_on_input),
+            e_score_correction_bias=correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            apply_routed_scale_to_output=True,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Route tokens through the experts and return the input-shaped result."""
+        input_shape = hidden_states.shape
+        tokens = hidden_states.view(-1, input_shape[-1])
+
+        # Router logits are computed in float32 and optionally soft-capped
+        # with tanh when the config provides a positive cap.
+        logits, _ = self.gate(tokens)
+        logits = logits.float()
+        cap = getattr(self.config, "moe_router_logit_softcapping", 0.0) or 0.0
+        if cap > 0.0:
+            logits = torch.tanh(logits / cap) * cap
+
+        mixed = self.experts(tokens, logits)
+        return mixed.view(input_shape)
+
+
+class LagunaAttention(nn.Module):
+    """Laguna attention with optional softplus output gating.
+
+    Supports per-layer sliding window attention when ``config.layer_types``
+    is present. Layers whose type is ``"sliding_attention"`` use
+    ``config.sliding_window``; all other layers (typically labelled
+    ``"full_attention"``) use full attention. When ``layer_types`` is
+    absent every layer defaults to full attention for backwards
+    compatibility.
+    """
+
+    def __init__(
+        self,
+        config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position_embeddings: int = 131072,
+        head_dim: int | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        attention_sink: bool = False,
+    ) -> None:
+        """Build projections, rotary embedding, attention and QK norms.
+
+        Args:
+            config: HF model config (gating, biases, rope and layer types are
+                read from it).
+            hidden_size: Model hidden size.
+            num_heads: Total number of query heads for this layer.
+            num_kv_heads: Total number of key/value heads.
+            max_position_embeddings: Maximum position passed to ``get_rope``.
+            head_dim: Per-head size; falls back to
+                ``hidden_size // num_heads`` when falsy.
+            cache_config: Forwarded to ``Attention``.
+            quant_config: Forwarded to the projections and ``Attention``.
+            prefix: Dotted module path; the layer index is extracted from it.
+            attention_sink: When truthy, a per-head ``sink`` parameter is
+                created and passed to ``Attention`` as ``sinks``.
+        """
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.max_position_embeddings = max_position_embeddings
+
+        # Gating flag
+        self.gating = config.gating
+
+        # Per-layer attention type (follows Gemma2/Cohere2 convention).
+        # Resolved once and reused for both the sliding window and the rope
+        # parameters below.
+        layer_types = getattr(config, "layer_types", None)
+        layer_type = (
+            layer_types[extract_layer_index(prefix)]
+            if layer_types is not None
+            else "full_attention"
+        )
+        is_sliding = layer_type == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
+
+        # QKV projection (bias controlled by config.qkv_bias)
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        # Output projection (bias controlled by config.attention_bias)
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        # Gating projection (Laguna-specific, optional). For a truthy
+        # config.gating:
+        #   - True or "per-head": one gate per head, broadcast across head_dim
+        #   - any other truthy value (e.g. "per-element"): one gate per
+        #     (head, head_dim) channel
+        if self.gating:
+            gate_per_head = self.gating is True or self.gating == "per-head"
+            g_out = (
+                self.total_num_heads
+                if gate_per_head
+                else self.total_num_heads * self.head_dim
+            )
+            self.g_proj = ColumnParallelLinear(
+                hidden_size,
+                g_out,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.g_proj",
+            )
+            self.gate_per_head = gate_per_head
+        else:
+            self.g_proj = None
+            self.gate_per_head = False
+
+        # Attention sinks (learnable per-head bias for SWA layers), one entry
+        # per local head. `requires_grad=False` belongs on the Parameter:
+        # passing it to torch.empty() leaves the Parameter at its default
+        # requires_grad=True.
+        sinks = None
+        if attention_sink:
+            self.sink = torch.nn.Parameter(
+                torch.empty(self.num_heads),
+                requires_grad=False,
+            )
+            sinks = self.sink
+
+        # Rotary embeddings (YaRN)
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            max_position=max_position_embeddings,
+            is_neox_style=True,
+            rope_parameters=self._resolve_rope_parameters(config, layer_type),
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            per_layer_sliding_window=self.sliding_window,
+            prefix=f"{prefix}.attn",
+            sinks=sinks,
+        )
+
+        # QK normalization (like Qwen3)
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+    @staticmethod
+    def _resolve_rope_parameters(config, layer_type: str) -> dict:
+        """Pick the rope parameter dict for a layer of the given type.
+
+        ``config.rope_parameters`` is either a flat dict (legacy) or a nested
+        ``{layer_type: rope_dict}`` (v5 Laguna-XS schema). The v5 form is
+        unhashable as-is and would crash ``get_rope``'s cache lookup, so the
+        layer's sub-dict is always pulled out before forwarding.
+        """
+        is_sliding = layer_type == "sliding_attention"
+
+        top_rope = getattr(config, "rope_parameters", None) or {}
+        is_nested = any(isinstance(v, dict) for v in top_rope.values())
+        if is_nested:
+            # Nested per-layer-type form.
+            base_rope = top_rope.get(layer_type) or top_rope.get("full_attention") or {}
+        else:
+            base_rope = top_rope
+
+        # Older flat-rope ckpts can carry a separate `swa_rope_parameters`
+        # for SWA layers. Prefer it when present; otherwise the nested
+        # rope dict above already supplies the correct sub-config.
+        swa_rope = getattr(config, "swa_rope_parameters", None)
+        if is_sliding and swa_rope is None and not is_nested:
+            logger.warning_once(
+                "Laguna config has sliding_attention layers but neither "
+                "`swa_rope_parameters` nor a nested per-layer-type "
+                "`rope_parameters` — SWA layers will reuse the global rope. "
+                "If the checkpoint was trained with distinct SWA rope "
+                "(theta / partial_rotary_factor), regenerate its HF config "
+                "to include either form."
+            )
+        rope_params = swa_rope if (is_sliding and swa_rope is not None) else base_rope
+
+        # `partial_rotary_factor` may live on the top-level config (main
+        # attention) or on the per-layer rope dict itself (e.g. SWA can
+        # differ). Inject the top-level value only if the dict doesn't
+        # already set it; a copy is made so the config dict is not mutated.
+        top_partial = getattr(config, "partial_rotary_factor", None)
+        if top_partial is not None and "partial_rotary_factor" not in rope_params:
+            rope_params = {**rope_params, "partial_rotary_factor": top_partial}
+        return rope_params
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        """Run QK-normed, rotary, optionally gated attention.
+
+        Args:
+            positions: Position tensor passed to the rotary embedding.
+            hidden_states: Input to the QKV (and gate) projections.
+
+        Returns:
+            The output of ``o_proj``.
+        """
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # RMSNorm is applied per head: view as (..., heads, head_dim), then
+        # restore the flat layout.
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+
+        # Apply gating if enabled (compute softplus in float32 for precision)
+        if self.gating and self.g_proj is not None:
+            gate, _ = self.g_proj(hidden_states)
+            gate = F.softplus(gate.float()).type_as(attn_output)
+            if self.gate_per_head:
+                # gate: [..., num_heads]; broadcast across head_dim
+                attn_shape = attn_output.shape
+                attn_output = (
+                    attn_output.view(*attn_shape[:-1], self.num_heads, self.head_dim)
+                    * gate.unsqueeze(-1)
+                ).view(attn_shape)
+            else:
+                attn_output = attn_output * gate
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class LagunaDecoderLayer(nn.Module):
+    """One Laguna decoder layer: attention followed by an MoE or dense MLP."""
+
+    def __init__(
+        self,
+        config,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        enable_eplb: bool = False,
+    ) -> None:
+        """Build the attention block, the feed-forward block and both norms.
+
+        Args:
+            config: HF model config.
+            cache_config: Forwarded to the attention block.
+            quant_config: Forwarded to the attention and feed-forward blocks.
+            prefix: Dotted module path; the layer index is extracted from it.
+            enable_eplb: Forwarded to ``LagunaMoE`` on MoE layers.
+        """
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        layer_idx = extract_layer_index(prefix)
+
+        # Does this layer use sliding window attention?
+        layer_types = getattr(config, "layer_types", None)
+        uses_swa = (
+            layer_types is not None and layer_types[layer_idx] == "sliding_attention"
+        )
+
+        # Attention sinks are enabled on SWA layers only, and only when
+        # the config asks for them.
+        attention_sink = uses_swa and getattr(
+            config, "swa_attention_sink_enabled", False
+        )
+
+        # Optional per-layer override of head count (Laguna-XS).
+        heads_by_layer = getattr(config, "num_attention_heads_per_layer", None)
+        if heads_by_layer is not None:
+            layer_num_heads = heads_by_layer[layer_idx]
+        else:
+            layer_num_heads = config.num_attention_heads
+
+        self.self_attn = LagunaAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=layer_num_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            head_dim=getattr(config, "head_dim", None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            attention_sink=attention_sink,
+        )
+
+        # MoE vs dense MLP selection (matches Qwen2/Qwen3 convention): a layer
+        # is MoE unless listed in mlp_only_layers, provided experts exist and
+        # the layer falls on the sparse step.
+        mlp_only_layers = getattr(config, "mlp_only_layers", [])
+        self.is_moe_layer = (
+            (layer_idx not in mlp_only_layers)
+            and (config.num_experts > 0)
+            and ((layer_idx + 1) % config.decoder_sparse_step == 0)
+        )
+
+        mlp_prefix = f"{prefix}.mlp"
+        if not self.is_moe_layer:
+            self.mlp = LagunaMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=mlp_prefix,
+            )
+        else:
+            self.mlp = LagunaMoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=mlp_prefix,
+                enable_eplb=enable_eplb,
+            )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Run the layer and return ``(hidden_states, residual)``.
+
+        When ``residual`` is ``None`` the incoming hidden states become the
+        residual; otherwise the input norm is called with both tensors and
+        returns the updated pair.
+        """
+        # Self Attention
+        if residual is not None:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        else:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        return self.mlp(hidden_states), residual
+
+
+@support_torch_compile
+class LagunaModel(nn.Module):
+    """Laguna decoder stack: embeddings, decoder layers and final norm.
+
+    Under pipeline parallelism, modules that do not belong to the current
+    rank are replaced by ``PPMissingLayer`` placeholders.
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the model from ``vllm_config``.
+
+        Args:
+            vllm_config: Global vLLM config; the HF config, cache config,
+                quant config and EPLB settings are read from it.
+            prefix: Dotted module path used to derive sub-module prefixes.
+        """
+        super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        eplb_config = vllm_config.parallel_config.eplb_config
+        # This is read before any LagunaMoE layer has had a chance to replace
+        # a missing (None) value with 0 on the shared config, so normalise it
+        # here as well; get_expert_mapping() forwards it as a count.
+        self.num_redundant_experts = eplb_config.num_redundant_experts or 0
+        self.config = config
+        self.quant_config = quant_config
+
+        # Disable the model-level sliding-window fallback in Attention.__init__.
+        # Laguna drives SWA per-layer via `layer_types`, passing
+        # `per_layer_sliding_window=self.sliding_window` (None for global
+        # layers). Without this, global layers whose `per_layer_sliding_window`
+        # is None would pick up `cache_config.sliding_window`
+        # (populated from `config.sliding_window`) as a fallback, silently
+        # applying a 512-token window to full-attention layers.
+        if cache_config is not None:
+            cache_config.sliding_window = None
+
+        self.vocab_size = config.vocab_size
+
+        # The embedding lives on the first PP rank, and also on the last rank
+        # when the word embeddings are tied.
+        if get_pp_group().is_first_rank or (
+            config.tie_word_embeddings and get_pp_group().is_last_rank
+        ):
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.embed_tokens",
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: LagunaDecoderLayer(
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+                enable_eplb=enable_eplb,
+            ),
+            prefix=f"{prefix}.layers",
+        )
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], config.hidden_size
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """Return the token embeddings for ``input_ids``."""
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors:
+        """Run this rank's decoder layers.
+
+        Args:
+            input_ids: Token ids; embedded on the first PP rank unless
+                ``inputs_embeds`` is given.
+            positions: Position tensor passed to every layer.
+            intermediate_tensors: Hidden states and residual from the previous
+                PP rank; required on non-first ranks.
+            inputs_embeds: Optional precomputed embeddings (first rank only).
+
+        Returns:
+            The normed hidden states on the last PP rank, otherwise an
+            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``.
+        """
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """Get expert parameter mapping for weight loading.
+
+        Returns mapping tuples of (param_name, weight_name, expert_id, shard_id)
+        that handle both weights and quantization scales.
+        """
+        return FusedMoE.make_expert_params_mapping(
+            self,
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+            num_redundant_experts=self.num_redundant_experts,
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load checkpoint tensors into this module's parameters.
+
+        Each ``(name, tensor)`` pair is tried, in order, against: attention
+        sinks, KV-cache quantization scales, the stacked QKV mapping, the
+        expert mapping, and finally a plain by-name load. Names that match
+        nothing on this rank are skipped silently.
+
+        Args:
+            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
+
+        Returns:
+            The set of parameter names that were loaded.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            # gate_proj and up_proj are loaded as separate Linears (see
+            # LagunaMLP) so no merge entry is needed here.
+        ]
+
+        # Suffixes to skip for GPTQ/modelopt models if param doesn't exist
+        ignore_suffixes = (
+            ".bias",
+            "_bias",
+            ".k_scale",
+            "_k_scale",
+            ".v_scale",
+            "_v_scale",
+            ".weight_scale",
+            "_weight_scale",
+            ".input_scale",
+            "_input_scale",
+        )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
+
+        tp_rank = get_tensor_model_parallel_rank()
+
+        for name, loaded_weight in weights:
+            # Handle attention sinks (distributed across ranks). Derive the
+            # per-rank slice from the parameter's own shape so per-layer
+            # variations in head count are handled correctly. A sink tensor
+            # with no matching parameter is skipped.
+            if "sink" in name:
+                param = params_dict.get(name)
+                if param is not None:
+                    layer_heads_per_rank = param.shape[0]
+                    layer_head_start = tp_rank * layer_heads_per_rank
+                    narrow_weight = loaded_weight.narrow(
+                        0, layer_head_start, layer_heads_per_rank
+                    )
+                    param.data.copy_(narrow_weight)
+                    loaded_params.add(name)
+                continue
+
+            # Handle KV cache quantization scales
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                assert loaded_weight.numel() == 1, (
+                    f"KV scale numel {loaded_weight.numel()} != 1"
+                )
+                loaded_weight = loaded_weight.squeeze()
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            # Handle stacked params (QKV for attention layers).
+            # NOTE: once a mapping matches, `name` is rewritten in place and
+            # stays rewritten even if a later check `continue`s, so the
+            # fallback branches below see the rewritten name.
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                # Skip expert weights - handled below via expert_params_mapping
+                if "mlp.experts" in name and "shared_expert" not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                if name.endswith(ignore_suffixes) and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                # Remap FP8 kv_scale names for backwards compatibility
+                if name.endswith("scale"):
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    # NOTE(review): if this ever yields None, the next
+                    # `weight_name not in name` test would raise TypeError;
+                    # confirm this path is unreachable in practice.
+                    if name is None:
+                        continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                break
+            else:
+                # No stacked mapping loaded the tensor.
+                # Try expert params mapping (handles weights + quantization scales)
+                is_expert_weight = False
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+
+                    # Mark as expert weight so we skip regular loading below
+                    is_expert_weight = True
+
+                    # Create mapped name without modifying original
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
+                    if (
+                        name_mapped.endswith(ignore_suffixes)
+                        and name_mapped not in params_dict
+                    ):
+                        continue
+                    if name_mapped not in params_dict:
+                        continue
+
+                    param = params_dict[name_mapped]
+                    # Use return_success to handle expert parallelism correctly
+                    weight_loader = typing.cast(
+                        Callable[..., bool], param.weight_loader
+                    )
+                    success = weight_loader(
+                        param,
+                        loaded_weight,
+                        name_mapped,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                        return_success=True,
+                    )
+                    if success:
+                        loaded_params.add(name_mapped)
+                        break
+                else:
+                    # Expert weight not mapped to this rank - skip
+                    if is_expert_weight:
+                        continue
+
+                    # Remap kv_scale names before the ignore_suffixes filter:
+                    # the suffix list includes .k_scale/.v_scale, so filtering
+                    # first drops the checkpoint key before remap can rewrite
+                    # it to the .attn.* name that exists in params_dict.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if name.endswith(ignore_suffixes) and name not in params_dict:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+
+        return loaded_params
+
+
+class LagunaForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+    """Laguna decoder stack with a language-model head."""
+
+    fall_back_to_pt_during_load = False
+
+    # Checkpoint q/k/v projections are packed into a single qkv_proj module.
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        """Build the decoder stack, the LM head and the logits processor.
+
+        Args:
+            vllm_config: Global vLLM config.
+            prefix: Dotted module path used to derive sub-module prefixes.
+        """
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = LagunaModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+        )
+
+        # The LM head only exists on the last PP rank; with tied embeddings
+        # it shares weights with the input embedding.
+        if not get_pp_group().is_last_rank:
+            self.lm_head = PPMissingLayer()
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+            if self.config.tie_word_embeddings:
+                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """Return the token embeddings for ``input_ids``."""
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors:
+        """Delegate to the decoder stack and return its output unchanged."""
+        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
+
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
+        """Project ``hidden_states`` to logits through the LM head."""
+        return self.logits_processor(self.lm_head, hidden_states)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """Return the expert parameter mapping of the wrapped model."""
+        return self.model.get_expert_mapping()
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load checkpoint weights, skipping ``lm_head.*`` when embeddings are tied.
+
+        Returns:
+            The set of parameter names reported as loaded by the loader.
+        """
+        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
+        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 80cc8b895345..eba288dcc77a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -150,6 +150,7 @@
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
"Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
+ "LagunaForCausalLM": ("laguna", "LagunaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"),
# For decapoda-research/llama-*
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 755fa56d294c..2347eae54c25 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -32,6 +32,10 @@
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningParser",
),
+ "poolside_v1": (
+ "poolside_v1_reasoning_parser",
+ "PoolsideV1ReasoningParser",
+ ),
"ernie45": (
"ernie45_reasoning_parser",
"Ernie45ReasoningParser",
diff --git a/vllm/reasoning/poolside_v1_reasoning_parser.py b/vllm/reasoning/poolside_v1_reasoning_parser.py
new file mode 100644
index 000000000000..30031d8513a9
--- /dev/null
+++ b/vllm/reasoning/poolside_v1_reasoning_parser.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Laguna reasoning parser.
+
+``DeepSeekV3ReasoningParser.is_reasoning_end`` walks the entire
+token sequence backwards and returns ``True`` on the first ``</think>`` it
+sees. When called on ``prompt_token_ids``, it mistakes any stray
+``</think>`` in conversation history, few-shot examples or tool descriptions
+for a template-injected "thinking already ended" marker. In the streaming
+path (see ``vllm/entrypoints/openai/chat_completion/serving.py``,
+``prompt_is_reasoning_end_arr``) that false positive short-circuits the
+reasoning parser for the whole response, so any ``<think>...</think>`` the
+model emits itself ends up in the content field instead of the reasoning
+field.
+
+Because our chat templates are more flexible, we instead scope
+the backward search to the current assistant turn: the
+walk terminates as soon as we hit the assistant start-of-message
+token. A ``</think>`` in a prior user turn or few-shot example is no longer
+visible.
+"""
+
+from collections.abc import Sequence
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
+from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
+
+
+class PoolsideV1ReasoningParser(DeepSeekV3ReasoningParser):
+ """Drop-in replacement for ``deepseek_v3`` that tolerates ``</think>``
+ tokens appearing anywhere in the prompt other than the generation prefix.
+
+ ``is_reasoning_end`` scans the token ids backwards and stops at the first
+ reasoning start token, reasoning end token, or assistant start-of-message
+ token it meets.
+ """
+
+ # Text of the token that opens an assistant turn; it must be present in
+ # the tokenizer vocabulary (checked in ``__init__``).
+ # NOTE(review): the literal is empty as written here, which looks like an
+ # angle-bracket token lost in extraction; confirm the intended token text.
+ _start_of_assistant_message = ""
+
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
+ """Initialise the parent parser and resolve the assistant-start token id.
+
+ Raises:
+ ValueError: If the tokenizer vocabulary lacks the assistant
+ start-of-message token.
+ """
+ super().__init__(tokenizer, *args, **kwargs)
+
+ if self._start_of_assistant_message not in self.vocab:
+ raise ValueError(
+ f"Tokenizer must contain {self._start_of_assistant_message!r} token"
+ )
+ self._start_of_assistant_message_token_id = self.vocab[
+ self._start_of_assistant_message
+ ]
+
+ def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+ """Return whether reasoning has ended within the current assistant turn.
+
+ Returns ``True`` when the wrapped parser is the identity parser, or when
+ the most recent marker found scanning backwards is the reasoning end
+ token; returns ``False`` otherwise, including when no marker is found.
+ """
+ # IdentityReasoningParser always returns True: no reasoning to parse.
+ if isinstance(self._parser, IdentityReasoningParser):
+ return True
+
+ assert isinstance(self._parser, DeepSeekR1ReasoningParser)
+ for tok_id in reversed(input_ids):
+ # Reasoning start token (presumably ``<think>``): reasoning is not
+ # yet ended.
+ if tok_id == self._parser.start_token_id:
+ return False
+ # Reasoning end token (presumably ``</think>``): reasoning has ended.
+ if tok_id == self._parser.end_token_id:
+ return True
+ # Assistant start-of-message token: reached the start of the current
+ # assistant turn without seeing either marker. Anything further back
+ # belongs to the prior conversation and should be ignored.
+ if tok_id == self._start_of_assistant_message_token_id:
+ return False
+ return False
+
+
+__all__ = ["PoolsideV1ReasoningParser"]
diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py
index 8a39ca825d5f..61a11cbcc376 100644
--- a/vllm/tool_parsers/__init__.py
+++ b/vllm/tool_parsers/__init__.py
@@ -66,6 +66,10 @@
"hermes_tool_parser",
"Hermes2ProToolParser",
),
+ "poolside_v1": (
+ "poolside_v1_tool_parser",
+ "PoolsideV1ToolParser",
+ ),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
diff --git a/vllm/tool_parsers/poolside_v1_tool_parser.py b/vllm/tool_parsers/poolside_v1_tool_parser.py
new file mode 100644
index 000000000000..f14b47362917
--- /dev/null
+++ b/vllm/tool_parsers/poolside_v1_tool_parser.py
@@ -0,0 +1,583 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Poolside v1 tool call parser (this text originally named the GLM-4 tool call
+parser) with incremental string streaming support.
+
+This parser fixes the streaming issue reported in Issue #32829 where long string
+parameters (e.g., file content with 4000+ characters of code) are buffered until
+complete, causing multi-second delays before the user sees any content.
+
+The fix streams string values incrementally as they arrive, providing a true
+streaming experience for long content.
+"""
+
+import ast
+import json
+from collections.abc import Sequence
+from typing import Any
+
+import partial_json_parser.core.complete
+import regex as re
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.chat_utils import make_tool_call_id
+from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+)
+from vllm.entrypoints.openai.engine.protocol import (
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall,
+ ToolCall,
+)
+from vllm.entrypoints.openai.responses.protocol import (
+ ResponsesRequest,
+)
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import (
+ Tool,
+ ToolParser,
+)
+
+logger = init_logger(__name__)
+
+
+class PoolsideV1ToolParser(ToolParser):
+ """Tool parser for ``poolside_v1`` (GLM-4-style) with incremental string streaming.
+
+ This parser emits tool-call deltas incrementally as arguments arrive.
+ For string-type parameters, content is streamed character-by-character
+ rather than waiting for the complete tag.
+ """
+
+ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
+ """Set up marker strings, compiled patterns and per-stream parser state.
+
+ Raises:
+ ValueError: if ``self.model_tokenizer`` is falsy after the base
+ constructor has run.
+ """
+ super().__init__(tokenizer, tools)
+ # Stateful streaming fields
+ self.current_tool_name_sent: bool = False
+ self.prev_tool_call_arr: list[dict[str, Any]] = []
+ self.current_tool_id: int = -1
+ self.streamed_args_for_tool: list[str] = []
+
+ # Marker strings that delimit a tool call, an argument key and an
+ # argument value in the model output.
+ # NOTE(review): all six literals are empty in the reviewed source and the
+ # three patterns below contain no literal delimiters; this looks like
+ # angle-bracket tokens were stripped -- confirm the intended values.
+ self.tool_call_start_token: str = ""
+ self.tool_call_end_token: str = ""
+ self.arg_key_start: str = ""
+ self.arg_key_end: str = ""
+ self.arg_val_start: str = ""
+ self.arg_val_end: str = ""
+
+ # Alias read by extract_tool_calls when slicing off leading content.
+ self.tool_calls_start_token = self.tool_call_start_token
+
+ # func_call_regex: one whole tool call; func_detail_regex: first line
+ # (group 1) and the rest (group 2); func_arg_regex: (key, value) pairs.
+ self.func_call_regex = re.compile(r".*?", re.DOTALL)
+ self.func_detail_regex = re.compile(
+ r"([^\n]*)\n(.*)", re.DOTALL
+ )
+ self.func_arg_regex = re.compile(
+ r"(.*?)\s*(.*?)", re.DOTALL
+ )
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ToolParser "
+ "constructor during construction."
+ )
+
+ # ``dict.get`` leaves these as None when the marker is not in the vocab.
+ self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+ # Text received by the streaming parser that has not been consumed yet.
+ self._buffer: str = ""
+
+ # Streaming state for incremental tool-call streaming
+ self._in_tool_call: bool = False
+ self._current_tool_name: str | None = None
+ self._pending_key: str | None = None
+ self._streaming_string_value: bool = False
+ # The four lists below are indexed by tool-call index (current_tool_id).
+ self._tool_call_ids: list[str] = []
+ self._args_started: list[bool] = []
+ self._args_closed: list[bool] = []
+ self._seen_keys: list[set[str]] = []
+
+ @staticmethod
+ def _deserialize(value: str) -> Any:
+     """Best-effort conversion of a raw argument string into a Python value.
+
+     The text is tried as JSON first, then as a Python literal. When neither
+     parser accepts it, the input string is returned unchanged.
+
+     Args:
+         value: Raw argument text taken from the model output.
+
+     Returns:
+         The decoded value, or ``value`` itself when it cannot be decoded.
+     """
+     try:
+         return json.loads(value)
+     except (json.JSONDecodeError, RecursionError):
+         # RecursionError: pathologically nested input; fall through.
+         pass
+
+     try:
+         return ast.literal_eval(value)
+     except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
+         # literal_eval raises TypeError for syntactically valid but
+         # unbuildable literals such as ``{[1]: 2}`` (unhashable key), and
+         # MemoryError/RecursionError for oversized input. The text comes
+         # from the model, so none of these may escape to the caller.
+         pass
+
+     return value
+
+ @staticmethod
+ def _json_escape_string_content(s: str) -> str:
+     """Return ``s`` escaped for placement between the quotes of a JSON string.
+
+     The surrounding quotes are not part of the result, so successive pieces
+     can be appended to a JSON string literal that is still open. A falsy
+     ``s`` yields an empty string.
+     """
+     # An empty JSON string literal stands in for falsy input, so the same
+     # quote-stripping slice applies to both cases.
+     quoted = json.dumps(s, ensure_ascii=False) if s else '""'
+     return quoted[1:-1]
+
+ @staticmethod
+ def _is_string_type(
+     tool_name: str,
+     arg_name: str,
+     tools: list[Tool] | None,
+ ) -> bool:
+     """Tell whether ``arg_name`` of tool ``tool_name`` is declared as a string.
+
+     Only the first tool whose ``function.name`` equals ``tool_name`` is
+     consulted. The result is False when ``tools`` is None, when no tool
+     matches (a debug message is logged), when the tool has no parameters,
+     or when the argument's declared ``type`` is anything but ``"string"``.
+     """
+     if tools is None:
+         return False
+
+     found = next((t for t in tools if t.function.name == tool_name), None)
+     if found is None:
+         logger.debug("No tool named '%s'.", tool_name)
+         return False
+
+     if found.function.parameters is None:
+         return False
+     properties = found.function.parameters.get("properties", {})
+     declared = properties.get(arg_name, {}).get("type", None)
+     return declared == "string"
+
+ @staticmethod
+ def _tools_enabled(request: ChatCompletionRequest) -> bool:
+     """Report whether the request carries tools and has not disabled them.
+
+     Missing ``tools`` / ``tool_choice`` attributes are read as None. Any
+     exception raised while inspecting the request is logged and treated
+     as "tools disabled".
+     """
+     try:
+         declared_tools = getattr(request, "tools", None)
+         choice = getattr(request, "tool_choice", None)
+         if not declared_tools:
+             return False
+         return choice != "none"
+     except Exception:
+         logger.exception("Failed to determine if tools are enabled.")
+         return False
+
+ def adjust_request(
+     self, request: ChatCompletionRequest | ResponsesRequest
+ ) -> ChatCompletionRequest | ResponsesRequest:
+     """Apply the base adjustments, then keep tool-call tokens in the output.
+
+     Returns the request object produced by the parent implementation.
+     """
+     request = super().adjust_request(request)
+     tools_active = request.tools and request.tool_choice != "none"
+     if tools_active:
+         # The tool-call start/end tokens must not be skipped during
+         # decoding. They are not marked as special tokens, but setting
+         # skip_special_tokens=False ensures proper handling in
+         # transformers 5.x where decoding behavior may have changed.
+         request.skip_special_tokens = False
+     return request
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ """Parse every complete tool call out of a finished model output.
+
+ Each ``func_call_regex`` match is split by ``func_detail_regex`` into a
+ name and an argument section, and the argument section is split into
+ key/value pairs by ``func_arg_regex``. Values whose parameter is not
+ declared as a string (per ``_is_string_type``) go through
+ ``_deserialize``; the others are kept as stripped text.
+
+ Returns:
+ ``tools_called=True`` with the parsed calls when at least one call was
+ built; ``content`` is then the text before the first start token, or
+ None if that text is empty or whitespace. Otherwise, including when
+ any exception is raised while parsing, ``tools_called=False`` with
+ ``content`` set to the unmodified ``model_output``.
+ """
+ matched_tool_calls = self.func_call_regex.findall(model_output)
+ logger.debug("model_output: %s", model_output)
+ try:
+ tool_calls: list[ToolCall] = []
+ for match in matched_tool_calls:
+ tc_detail = self.func_detail_regex.search(match)
+ if not tc_detail:
+ # A match without a recognisable name/args layout is skipped.
+ logger.warning(
+ "Failed to parse tool call details from: %s",
+ match,
+ )
+ continue
+ tc_name = tc_detail.group(1).strip()
+ tc_args = tc_detail.group(2)
+ pairs = self.func_arg_regex.findall(tc_args) if tc_args else []
+ arg_dct: dict[str, Any] = {}
+ for key, value in pairs:
+ arg_key = key.strip()
+ arg_val = value.strip()
+ # Declared-string parameters stay as raw text; everything else
+ # is decoded on a best-effort basis.
+ if not self._is_string_type(tc_name, arg_key, request.tools):
+ arg_val = self._deserialize(arg_val)
+ logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
+ # A repeated key overwrites the earlier value.
+ arg_dct[arg_key] = arg_val
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=tc_name,
+ arguments=json.dumps(arg_dct, ensure_ascii=False),
+ ),
+ )
+ )
+ except Exception:
+ # Best effort: on any failure hand the whole output back as content.
+ logger.exception("Failed to extract tool call spec")
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ else:
+ if len(tool_calls) > 0:
+ # Content is whatever precedes the first tool-call start token.
+ content: str | None = model_output[
+ : model_output.find(self.tool_calls_start_token)
+ ]
+ # Normalize empty/whitespace-only content to None
+ if not content or not content.strip():
+ content = None
+ return ExtractedToolCallInformation(
+ tools_called=True, tool_calls=tool_calls, content=content
+ )
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> DeltaMessage | None:
+ """Consume one streamed text delta and emit content / tool-call deltas.
+
+ ``delta_text`` is appended to ``self._buffer`` and the buffer is then
+ consumed in a loop until nothing more can be decided from the text
+ received so far. Of the parameters, only ``delta_text`` and ``request``
+ are read in this body.
+
+ Returns:
+ A ``DeltaMessage`` carrying plain content and/or one ``DeltaToolCall``
+ per tool index touched by this call. When nothing was produced the
+ result is None, or an empty-content message if ``request.logprobs``
+ is truthy. When tools are not enabled for the request the delta is
+ passed through as content (None for an empty delta).
+ """
+ if not self._tools_enabled(request):
+ return DeltaMessage(content=delta_text) if delta_text else None
+
+ self._buffer += delta_text
+
+ # Deltas produced during this call, keyed by tool-call index.
+ pending_deltas: dict[int, DeltaToolCall] = {}
+ content: str | None = None
+
+ while True:
+ # State 1: outside a tool call -- look for the start token.
+ if not self._in_tool_call:
+ start_idx = self._buffer.find(self.tool_call_start_token)
+ if start_idx == -1:
+ # Check for partial start token at end of buffer
+ for i in range(1, len(self.tool_call_start_token)):
+ if self._buffer.endswith(self.tool_call_start_token[:i]):
+ # Emit everything before the possible partial token and
+ # keep the partial in the buffer for the next delta.
+ out = self._buffer[:-i]
+ self._buffer = self._buffer[-i:]
+ if out:
+ content = (content or "") + out
+ break
+ else:
+ # No partial start token: the whole buffer is content.
+ out = self._buffer
+ self._buffer = ""
+ if out:
+ content = (content or "") + out
+ break
+
+ if start_idx > 0:
+ content = (content or "") + self._buffer[:start_idx]
+ self._buffer = self._buffer[start_idx:]
+
+ self._buffer = self._buffer[len(self.tool_call_start_token) :]
+ self._begin_tool_call()
+ continue
+
+ # Parse tool name first
+ if not self.current_tool_name_sent:
+ # The name ends at the first newline, key start marker or
+ # tool-call end token, whichever comes first.
+ nl = self._buffer.find("\n")
+ ak = self._buffer.find(self.arg_key_start)
+ end = self._buffer.find(self.tool_call_end_token)
+ candidates = [i for i in [nl, ak, end] if i != -1]
+ if not candidates:
+ break
+ cut = min(candidates)
+ tool_name = self._buffer[:cut].strip()
+ if tool_name == "" and cut == end:
+ # Handle an empty tool call (no name before the end token).
+ # Consume the tokens and reset state to avoid infinite loop.
+ self._buffer = self._buffer[end + len(self.tool_call_end_token) :]
+ self._finish_tool_call()
+ self._revert_last_tool_call_state()
+ continue
+
+ # A terminating newline is dropped; other terminators stay in the
+ # buffer for the states below.
+ if cut == nl:
+ self._buffer = self._buffer[nl + 1 :]
+ else:
+ self._buffer = self._buffer[cut:]
+
+ self._current_tool_name = tool_name
+ self.current_tool_name_sent = True
+ self._update_tool_name(pending_deltas, tool_name)
+ continue
+
+ assert self._current_tool_name is not None
+
+ # Handle incremental string value streaming
+ if self._streaming_string_value:
+ val_end = self._buffer.find(self.arg_val_end)
+ if val_end != -1:
+ # Value complete: emit the remainder plus the closing quote.
+ raw_content = self._buffer[:val_end]
+ self._buffer = self._buffer[val_end + len(self.arg_val_end) :]
+ self._streaming_string_value = False
+ self._pending_key = None
+
+ escaped = self._json_escape_string_content(raw_content)
+ frag = escaped + '"'
+ self.streamed_args_for_tool[self.current_tool_id] += frag
+ self._update_tool_args(pending_deltas, frag)
+ continue
+
+ # Check for a partial value-end marker at the end of the buffer
+ safe_len = len(self._buffer)
+ for i in range(1, len(self.arg_val_end)):
+ if self._buffer.endswith(self.arg_val_end[:i]):
+ safe_len = len(self._buffer) - i
+ break
+
+ # Emit only the part that cannot belong to the end marker.
+ if safe_len > 0:
+ to_emit = self._buffer[:safe_len]
+ self._buffer = self._buffer[safe_len:]
+ escaped = self._json_escape_string_content(to_emit)
+ if escaped:
+ self.streamed_args_for_tool[self.current_tool_id] += escaped
+ self._update_tool_args(pending_deltas, escaped)
+ break
+
+ # If we have a pending key, parse its value
+ if self._pending_key is not None:
+ val_pos = self._buffer.find(self.arg_val_start)
+ if val_pos == -1:
+ break
+ if val_pos > 0:
+ # Text between the key and the value start marker is dropped.
+ self._buffer = self._buffer[val_pos:]
+
+ key = (self._pending_key or "").strip()
+
+ is_string = self._is_string_type(
+ self._current_tool_name, key, request.tools
+ )
+
+ if is_string:
+ # String type: stream incrementally
+ self._buffer = self._buffer[len(self.arg_val_start) :]
+
+ # A key already emitted for this tool call is not emitted again.
+ if key in self._seen_keys[self.current_tool_id]:
+ self._pending_key = None
+ continue
+
+ self._seen_keys[self.current_tool_id].add(key)
+ key_json = json.dumps(key, ensure_ascii=False)
+
+ # Open the JSON object with the first key, otherwise add a
+ # separator; the trailing quote opens the string value.
+ if not self._args_started[self.current_tool_id]:
+ frag = "{" + key_json + ': "'
+ self._args_started[self.current_tool_id] = True
+ else:
+ frag = ", " + key_json + ': "'
+
+ self.streamed_args_for_tool[self.current_tool_id] += frag
+ self._streaming_string_value = True
+ self._update_tool_args(pending_deltas, frag)
+ continue
+
+ # Non-string type: wait for complete value
+ val_end = self._buffer.find(self.arg_val_end)
+ if val_end == -1:
+ break
+
+ raw_val = self._buffer[len(self.arg_val_start) : val_end].strip()
+ self._buffer = self._buffer[val_end + len(self.arg_val_end) :]
+ self._pending_key = None
+
+ # None means the key was empty or already seen; nothing is emitted.
+ frag_or_none = self._append_arg_fragment(key=key, raw_val=raw_val)
+ if frag_or_none:
+ self._update_tool_args(pending_deltas, frag_or_none)
+ continue
+
+ # Parse next arg or close
+ end_pos = self._buffer.find(self.tool_call_end_token)
+ key_pos = self._buffer.find(self.arg_key_start)
+ # The call closes when the end token precedes any further key.
+ if end_pos != -1 and (key_pos == -1 or end_pos < key_pos):
+ self._buffer = self._buffer[end_pos + len(self.tool_call_end_token) :]
+ frag_or_none = self._close_args_if_needed()
+ # Finalize prev_tool_call_arr with complete parsed arguments
+ if self._current_tool_name:
+ try:
+ full_args_str = self.streamed_args_for_tool[
+ self.current_tool_id
+ ]
+ args_dict = json.loads(full_args_str)
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": self._current_tool_name,
+ "arguments": args_dict,
+ }
+ except (json.JSONDecodeError, IndexError) as e:
+ logger.warning(
+ "Failed to finalize tool call state for tool %d: %s",
+ self.current_tool_id,
+ e,
+ )
+ # The closing fragment is emitted after the per-call state is
+ # reset; current_tool_id is unchanged, so it still goes to this
+ # call's delta.
+ self._finish_tool_call()
+ if frag_or_none:
+ self._update_tool_args(pending_deltas, frag_or_none)
+ continue
+
+ if key_pos == -1:
+ break
+ if key_pos > 0:
+ # Text before the key start marker is dropped.
+ self._buffer = self._buffer[key_pos:]
+ key_end = self._buffer.find(self.arg_key_end)
+ if key_end == -1:
+ break
+ key = self._buffer[len(self.arg_key_start) : key_end]
+ self._buffer = self._buffer[key_end + len(self.arg_key_end) :]
+ self._pending_key = key
+ continue
+
+ tool_calls = list(pending_deltas.values())
+ if content is None and len(tool_calls) == 0:
+ # Nothing to report for this delta.
+ if request.logprobs:
+ return DeltaMessage(content="")
+ return None
+ return DeltaMessage(content=content, tool_calls=tool_calls)
+
+ def _ensure_tool_state(self) -> None:
+     """Grow every per-tool list until ``current_tool_id`` is a valid index.
+
+     New slots get a fresh random tool-call id, an empty argument string, an
+     empty call record, False for both the started and closed flags, and an
+     empty set of seen keys.
+     """
+     required = self.current_tool_id + 1
+
+     while len(self._tool_call_ids) < required:
+         self._tool_call_ids.append(
+             make_tool_call_id(id_type="random", func_name=None, idx=None)
+         )
+
+     # Each remaining list is padded with the zero value of its element
+     # type: str() == "", dict() == {}, bool() is False, set() == set().
+     padded_lists = (
+         (self.streamed_args_for_tool, str),
+         (self.prev_tool_call_arr, dict),
+         (self._args_started, bool),
+         (self._args_closed, bool),
+         (self._seen_keys, set),
+     )
+     for per_tool, make_empty in padded_lists:
+         while len(per_tool) < required:
+             per_tool.append(make_empty())
+
+ def _begin_tool_call(self) -> None:
+     """Advance to the next tool-call index and reset the per-call state."""
+     # -1 marks "no tool call yet"; the first call gets index 0.
+     previous = self.current_tool_id
+     self.current_tool_id = 0 if previous == -1 else previous + 1
+     self._ensure_tool_state()
+
+     self._in_tool_call = True
+     self.current_tool_name_sent = False
+     self._streaming_string_value = False
+     self._current_tool_name = None
+     self._pending_key = None
+
+ def _finish_tool_call(self) -> None:
+     """Leave the in-tool-call state and clear name, key and string flags."""
+     self._in_tool_call = self._streaming_string_value = False
+     self._current_tool_name = self._pending_key = None
+
+ def _revert_last_tool_call_state(self) -> None:
+     """Drop the newest slot of every per-tool list and step the index back.
+
+     Does nothing when no tool call has been started
+     (``current_tool_id`` below zero).
+     """
+     if self.current_tool_id < 0:
+         return
+
+     per_tool_lists = (
+         self._tool_call_ids,
+         self.streamed_args_for_tool,
+         self.prev_tool_call_arr,
+         self._args_started,
+         self._args_closed,
+         self._seen_keys,
+     )
+     for per_tool in per_tool_lists:
+         per_tool.pop()
+     self.current_tool_id -= 1
+
+ def _get_or_create_delta(self, pending: dict[int, DeltaToolCall]) -> DeltaToolCall:
+     """Return the pending delta for the current tool index, creating it once.
+
+     A newly created delta is stored in ``pending`` under the current index
+     with an empty ``DeltaFunctionCall``.
+     """
+     slot = self.current_tool_id
+     try:
+         current = pending[slot]
+     except KeyError:
+         current = DeltaToolCall(
+             index=slot,
+             function=DeltaFunctionCall(),
+         )
+         pending[slot] = current
+     assert current.function is not None
+     return current
+
+ def _update_tool_name(
+     self, pending: dict[int, DeltaToolCall], tool_name: str
+ ) -> None:
+     """Record the tool name and put it on the pending delta for this call.
+
+     The stored call record takes its name from ``self._current_tool_name``
+     and starts with empty arguments; the delta receives the call id, the
+     ``"function"`` type, ``tool_name`` and an empty argument string if no
+     arguments were set yet.
+     """
+     slot = self.current_tool_id
+     fresh_record = {
+         "name": self._current_tool_name,
+         "arguments": {},
+     }
+     self.prev_tool_call_arr[slot] = fresh_record
+
+     out = self._get_or_create_delta(pending)
+     out.id = self._tool_call_ids[slot]
+     out.type = "function"
+     assert out.function is not None
+     out.function.name = tool_name
+     if out.function.arguments is None:
+         out.function.arguments = ""
+
+ @staticmethod
+ def _complete_json_prefix(
+     json_prefix: str,
+     allowed_partial_types: Allow,
+ ) -> dict | None:
+     """Close a partial JSON prefix and parse the completed text.
+
+     Returns the value parsed from ``json_prefix`` plus the completion
+     suggested by ``partial_json_parser``, or None if anything fails.
+
+     Note: ``partial_json_parser`` strips trailing whitespace before
+     parsing (``complete.py:20``), so the prefix slice it returns can be
+     shorter than ``json_prefix``. This parser builds ``json_prefix``
+     itself and relies on it being a valid prefix; only the completion
+     part of the ``fix`` result is used, appended to the untouched
+     ``json_prefix``.
+     """
+     try:
+         fixed = partial_json_parser.core.complete.fix(
+             json_prefix,
+             allowed_partial_types,
+         )
+         _, closing_suffix = fixed
+         return json.loads(json_prefix + closing_suffix)
+     except Exception:
+         return None
+
+ def _update_tool_args(
+     self, pending: dict[int, DeltaToolCall], fragment: str
+ ) -> None:
+     """Refresh the parsed-arguments snapshot and queue ``fragment`` for output.
+
+     The snapshot in ``prev_tool_call_arr`` is replaced only when the
+     arguments streamed so far can be completed into valid JSON. The
+     fragment is always appended to the pending delta's argument text.
+     """
+     slot = self.current_tool_id
+     snapshot = self._complete_json_prefix(
+         self.streamed_args_for_tool[slot],
+         Allow.ALL,
+     )
+     if snapshot is not None:
+         self.prev_tool_call_arr[slot]["arguments"] = snapshot
+
+     out = self._get_or_create_delta(pending)
+     assert out.function is not None
+     already_queued = out.function.arguments
+     if already_queued is None:
+         already_queued = ""
+     out.function.arguments = already_queued + fragment
+
+ def _append_arg_fragment(
+     self,
+     *,
+     key: str,
+     raw_val: str,
+ ) -> str | None:
+     """Append one ``key: value`` pair to the streamed arguments of this call.
+
+     Returns the JSON fragment that was appended, or None when the stripped
+     key is empty or was already emitted for the current tool call.
+     """
+     key = key.strip()
+     if not key:
+         return None
+
+     slot = self.current_tool_id
+     if key in self._seen_keys[slot]:
+         return None
+
+     # The caller routes only non-string parameters here (it checks
+     # _is_string_type first), so the raw text is always decoded.
+     decoded: Any = self._deserialize(raw_val)
+     pair = (
+         json.dumps(key, ensure_ascii=False)
+         + ": "
+         + json.dumps(decoded, ensure_ascii=False)
+     )
+
+     if self._args_started[slot]:
+         fragment = ", " + pair
+     else:
+         # First pair of this call: open the JSON object.
+         fragment = "{" + pair
+         self._args_started[slot] = True
+
+     self._seen_keys[slot].add(key)
+     self.streamed_args_for_tool[slot] += fragment
+     return fragment
+
+ def _close_args_if_needed(self) -> str | None:
+     """Close the argument object of the current call exactly once.
+
+     Returns ``"}"`` when arguments were started, ``"{}"`` when none were
+     (the streamed text is then replaced by ``"{}"``), and None when the
+     arguments of this call were already closed.
+     """
+     slot = self.current_tool_id
+     if self._args_closed[slot]:
+         return None
+     self._args_closed[slot] = True
+
+     if self._args_started[slot]:
+         fragment = "}"
+         self.streamed_args_for_tool[slot] += fragment
+         return fragment
+
+     fragment = "{}"
+     self.streamed_args_for_tool[slot] = fragment
+     return fragment
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index bb6ad1056b7b..47b74093b06c 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -127,6 +127,7 @@ def __getitem__(self, key):
qwen3_next="Qwen3NextConfig",
qwen3_5="Qwen3_5Config",
qwen3_5_moe="Qwen3_5MoeConfig",
+ laguna="LagunaConfig",
lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config",
)
@@ -409,22 +410,33 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
ompe = getattr(config, "original_max_position_embeddings", None)
if Version(version("transformers")) < Version("5.0.0"):
- # Transformers v4 installed, legacy config fields may be present
- if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
- config.rope_parameters = rope_scaling
- if (
- rope_theta is not None
- or partial_rotary_factor is not None
- or ompe is not None
- ) and not getattr(config, "rope_parameters", None):
- config.rope_parameters = {"rope_type": "default"}
- # Patch legacy fields into rope_parameters
- if rope_theta is not None:
- config.rope_parameters["rope_theta"] = rope_theta
- if partial_rotary_factor is not None:
- config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
- if ompe is not None:
- config.rope_parameters["original_max_position_embeddings"] = ompe
+ # Transformers v4 installed, legacy config fields may be present.
+ existing_rp = getattr(config, "rope_parameters", None)
+ if isinstance(existing_rp, dict) and is_rope_parameters_nested(existing_rp):
+ # Interleaved-attention models (e.g. Laguna-XS.2) ship a nested
+ # {layer_type: {...}} rope_parameters that the model code indexes
+ # by layer_type. The per-layer-type sub-dicts already carry the
+ # correct rope_theta / partial_rotary_factor / ompe (the converter
+ # places top-level legacy fields inside full_attention), so don't
+ # merge top-level fields here — that would shadow the per-type
+ # values and break sliding-attention layers.
+ pass
+ else:
+ if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
+ config.rope_parameters = rope_scaling
+ if (
+ rope_theta is not None
+ or partial_rotary_factor is not None
+ or ompe is not None
+ ) and not getattr(config, "rope_parameters", None):
+ config.rope_parameters = {"rope_type": "default"}
+ # Patch legacy fields into rope_parameters
+ if rope_theta is not None:
+ config.rope_parameters["rope_theta"] = rope_theta
+ if partial_rotary_factor is not None:
+ config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+ if ompe is not None:
+ config.rope_parameters["original_max_position_embeddings"] = ompe
elif rope_theta is not None or getattr(config, "rope_parameters", None):
# Transformers v5 installed
# Patch these fields in case they used non-standard names
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 667ed5a2596c..8c4d01a428bd 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -45,6 +45,7 @@
# `FalconConfig` class from the official HuggingFace transformers library.
"RWConfig": "vllm.transformers_utils.configs.falcon",
"JAISConfig": "vllm.transformers_utils.configs.jais",
+ "LagunaConfig": "vllm.transformers_utils.configs.laguna",
"Lfm2MoeConfig": "vllm.transformers_utils.configs.lfm2_moe",
"MedusaConfig": "vllm.transformers_utils.configs.medusa",
"MiDashengLMConfig": "vllm.transformers_utils.configs.midashenglm",
@@ -105,6 +106,7 @@
"IsaacConfig",
"RWConfig",
"JAISConfig",
+ "LagunaConfig",
"Lfm2MoeConfig",
"MedusaConfig",
"MiDashengLMConfig",
diff --git a/vllm/transformers_utils/configs/laguna.py b/vllm/transformers_utils/configs/laguna.py
new file mode 100644
index 000000000000..2702d3af5aa1
--- /dev/null
+++ b/vllm/transformers_utils/configs/laguna.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class LagunaConfig(PretrainedConfig):
+ """Configuration class for models with ``model_type == "laguna"``.
+
+ The constructor stores its keyword arguments as same-named attributes,
+ with two exceptions: ``rope_parameters`` is folded into ``rope_theta`` /
+ ``rope_scaling`` rather than stored, and ``tie_word_embeddings`` is
+ forwarded to ``PretrainedConfig.__init__`` together with ``**kwargs``.
+ """
+
+ model_type = "laguna"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Per-module tensor-parallel sharding styles, keyed by module-name pattern.
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.g_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ # Per-stage (input names, output names) pairs for pipeline parallelism.
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size: int = 100352,
+ hidden_size: int = 2048,
+ intermediate_size: int = 8192,
+ num_hidden_layers: int = 40,
+ num_attention_heads: int = 48,
+ num_key_value_heads: int = 8,
+ head_dim: int = 128,
+ qkv_bias: bool = False,
+ attention_bias: bool = False,
+ gating: bool | str = True,
+ hidden_act: str = "silu",
+ max_position_embeddings: int = 131072,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-6,
+ use_cache: bool = True,
+ tie_word_embeddings: bool = False,
+ rope_theta: float = 500000.0,
+ rope_scaling: dict | None = None,
+ rope_parameters: dict | None = None,
+ partial_rotary_factor: float = 1.0,
+ attention_dropout: float = 0.0,
+ sliding_window: int | None = None,
+ layer_types: list[str] | None = None,
+ swa_attention_sink_enabled: bool = False,
+ swa_rope_parameters: dict | None = None,
+ num_attention_heads_per_layer: list[int] | None = None,
+ num_experts: int = 256,
+ num_experts_per_tok: int = 8,
+ moe_intermediate_size: int = 512,
+ shared_expert_intermediate_size: int = 512,
+ norm_topk_prob: bool = True,
+ decoder_sparse_step: int = 1,
+ mlp_only_layers: list[int] | None = None,
+ router_aux_loss_coef: float = 0.001,
+ output_router_logits: bool = False,
+ moe_routed_scaling_factor: float = 1.0,
+ moe_apply_router_weight_on_input: bool = False,
+ **kwargs,
+ ):
+ # None stands for the default ``[0]``; a fresh list is built per call.
+ if mlp_only_layers is None:
+ mlp_only_layers = [0]
+
+ # Accept either v4-style (rope_theta + rope_scaling) or v5-style
+ # (rope_parameters). Translate v5 → v4 so downstream code has one path.
+ if rope_parameters is not None:
+ # Work on a shallow copy so the caller's dict is not mutated.
+ rp = dict(rope_parameters)
+ rope_theta = float(rp.pop("rope_theta", rope_theta))
+ rt = rp.pop("rope_type", None)
+ # A non-default rope_type replaces any rope_scaling argument with the
+ # remaining keys; otherwise leftover keys become a "default"-typed
+ # rope_scaling only when no rope_scaling was passed.
+ # NOTE(review): with a nested {layer_type: {...}} dict the per-type
+ # sub-dicts would end up as leftover keys here -- confirm intended.
+ if rt is not None and rt != "default":
+ rope_scaling = {"rope_type": rt, **rp}
+ elif rp and rope_scaling is None:
+ rope_scaling = {"rope_type": "default", **rp}
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.head_dim = head_dim
+ self.qkv_bias = qkv_bias
+ self.attention_bias = attention_bias
+ self.gating = gating
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ # NOTE(review): ``rope_parameters`` itself is not assigned in this body;
+ # whether the base class sets it is not visible here.
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.partial_rotary_factor = partial_rotary_factor
+ self.attention_dropout = attention_dropout
+ self.sliding_window = sliding_window
+ self.layer_types = layer_types
+ self.swa_attention_sink_enabled = swa_attention_sink_enabled
+ self.swa_rope_parameters = swa_rope_parameters
+ self.num_attention_heads_per_layer = num_attention_heads_per_layer
+ self.num_experts = num_experts
+ self.num_experts_per_tok = num_experts_per_tok
+ self.moe_intermediate_size = moe_intermediate_size
+ self.shared_expert_intermediate_size = shared_expert_intermediate_size
+ self.norm_topk_prob = norm_topk_prob
+ self.decoder_sparse_step = decoder_sparse_step
+ self.mlp_only_layers = mlp_only_layers
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.output_router_logits = output_router_logits
+ self.moe_routed_scaling_factor = moe_routed_scaling_factor
+ self.moe_apply_router_weight_on_input = moe_apply_router_weight_on_input
+
+ # The base initialiser runs last and receives the remaining kwargs.
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["LagunaConfig"]