From ef897737805474098e9b796589b6712becfe0446 Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Wed, 7 Jan 2026 11:58:38 +0000 Subject: [PATCH 01/14] Add version 0 of LFM2 --- python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/lfm2.py | 95 ++ .../sglang/srt/model_executor/model_runner.py | 3 +- python/sglang/srt/models/lfm2.py | 875 ++++++++++++++++++ 4 files changed, 974 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/configs/lfm2.py create mode 100644 python/sglang/srt/models/lfm2.py diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index b35cc1dc5f23..eadb3f03521a 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -6,6 +6,7 @@ from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.falcon_h1 import FalconH1Config from sglang.srt.configs.janus_pro import MultiModalityConfig +from sglang.srt.configs.lfm2 import Lfm2Config from sglang.srt.configs.jet_nemotron import JetNemotronConfig from sglang.srt.configs.jet_vlm import JetVLMConfig from sglang.srt.configs.kimi_linear import KimiLinearConfig @@ -40,6 +41,7 @@ "DotsVLMConfig", "DotsOCRConfig", "FalconH1Config", + "Lfm2Config", "NemotronHConfig", "NemotronH_Nano_VL_V2_Config", "JetNemotronConfig", diff --git a/python/sglang/srt/configs/lfm2.py b/python/sglang/srt/configs/lfm2.py new file mode 100644 index 000000000000..b94cd5f11afa --- /dev/null +++ b/python/sglang/srt/configs/lfm2.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2024 Liquid AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LFM2 (Liquid Foundation Model 2) configuration""" + +from typing import List, Optional + +from transformers import Lfm2Config as HFLfm2Config +from transformers import CONFIG_MAPPING +from transformers.utils import logging + +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape + +logger = logging.get_logger(__name__) + + +class Lfm2Config(HFLfm2Config): + """ + SGLang configuration for LFM2 models. + + Extends HuggingFace's Lfm2Config with hybrid model properties needed by SGLang. + LFM2 uses a hybrid architecture mixing full attention and ShortConv layers. + """ + + @property + def full_attention_layer_ids(self) -> List[int]: + """Return indices of attention layers for KV cache.""" + return [i for i, lt in enumerate(self.layer_types) if lt == "full_attention"] + + @property + def linear_layer_ids(self) -> List[int]: + """Return indices of conv layers for conv state cache.""" + return [i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")] + + @property + def mamba_chunk_size(self) -> int: + """Return chunk size for Mamba2 backend. LFM2 doesn't use chunking, return 1.""" + return 1 + + @property + def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: + """ + Get cache params for HybridReqToTokenPool initialization. + + LFM2 uses ShortConv layers with a small fixed-size cache (kernel_size - 1). + Unlike full Mamba2 models, LFM2 only uses the conv state, not SSM temporal state. 
+ """ + from sglang.srt.layers.dp_attention import get_attention_tp_size + + conv_layer_ids = self.linear_layer_ids + if not conv_layer_ids: + return None + + hidden_size = self.hidden_size + # conv_L_cache in config is kernel_size (e.g., 3) + conv_kernel = int(self.conv_L_cache) + L_cache = conv_kernel - 1 # actual cache size (e.g., 2 for kernel=3) + + # get_attention_tp_size() requires initialization, default to 1 if not available + try: + tp_size = get_attention_tp_size() + except (AssertionError, RuntimeError): + tp_size = 1 + + # For ShortConv layers, we use a simplified Mamba2StateShape + # LFM2 doesn't use SSM state (state_size=0), only conv state + shape = Mamba2StateShape.create( + tp_world_size=tp_size, + intermediate_size=hidden_size, + n_groups=1, # ShortConv doesn't use grouping + num_heads=1, # ShortConv is not multi-head + head_dim=hidden_size, # Conv operates on full hidden dim + state_size=0, # No SSM temporal state for ShortConv + conv_kernel=conv_kernel, + ) + + return Mamba2CacheParams(shape=shape, layers=conv_layer_ids) + + +# Override HuggingFace's Lfm2Config with our extended version +# Cannot use .register() because lfm2 is already registered by transformers +# Directly modify the internal _extra_content dict instead +CONFIG_MAPPING._extra_content["lfm2"] = Lfm2Config +logger.info("Registered SGLang Lfm2Config to override HuggingFace's version") diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 677d71d8c69d..dbb732544b17 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -38,6 +38,7 @@ NemotronH_Nano_VL_V2_Config, NemotronHConfig, Qwen3NextConfig, + Lfm2Config, ) from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat @@ -1478,7 +1479,7 @@ def hybrid_gdn_config(self): @property def mamba2_config(self): config = self.model_config.hf_config 
- if isinstance(config, FalconH1Config | NemotronHConfig): + if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config): return config if isinstance(config, NemotronH_Nano_VL_V2_Config): return config.llm_config diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py new file mode 100644 index 000000000000..814a4bcb9373 --- /dev/null +++ b/python/sglang/srt/models/lfm2.py @@ -0,0 +1,875 @@ +# sglang/srt/models/lfm2.py +# LFM2 implementation for SGLang +# Based on HuggingFace's implementation +# +# This version uses SGLang's hybrid caching infrastructure (HybridReqToTokenPool + MambaPool) +# while keeping the original working model structure and weight names. +# +# IMPORTANT: This file patches Lfm2Config at MODULE IMPORT TIME to add the properties +# required by model_runner.py for hybrid cache detection. You must also add Lfm2Config +# to the isinstance check in model_runner.py's mamba2_config property. + +import logging +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +import torch.nn.functional as F + +from transformers import Lfm2Config + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix, make_layers + +# Import for Mamba state management - use SGLang's actual classes +from sglang.srt.configs.mamba_utils import ( + Mamba2StateShape, + 
Mamba2CacheParams, +) + +logger = logging.getLogger(__name__) + +# Debug flag - set to True to enable debug logging +DEBUG_LFM2 = False + + +def debug_tensor(name: str, t: torch.Tensor): + if DEBUG_LFM2: + logger.info(f"DEBUG {name}: shape={t.shape}, dtype={t.dtype}, " + f"min={t.min().item():.6f}, max={t.max().item():.6f}, " + f"mean={t.mean().item():.6f}, std={t.std().item():.6f}") + + +# ============================================================================ +# Config Patching - MUST happen at module import time +# ============================================================================ + +def _patch_lfm2_config_class(): + """ + Patch Lfm2Config CLASS (not instance) with properties required by model_runner.py. + + This must happen at module import time, BEFORE model_runner.py checks the config type. + + model_runner.py's mamba2_config property does: + if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config): + return config + + And then uses config.mamba2_cache_params to set up HybridReqToTokenPool. + """ + if getattr(Lfm2Config, '_sglang_patched', False): + return + + def _get_full_attention_layer_ids(self) -> List[int]: + """Return indices of attention layers for KV cache.""" + return [i for i, lt in enumerate(self.layer_types) if lt == "full_attention"] + + def _get_linear_layer_ids(self) -> List[int]: + """Return indices of conv layers for conv state cache.""" + return [i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")] + + def _get_mamba_chunk_size(self) -> int: + """Return chunk size for Mamba2 backend. LFM2 doesn't use chunking, return 1.""" + return 1 + + def _get_mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: + """ + Get cache params for HybridReqToTokenPool initialization. + + Uses SGLang's Mamba2StateShape to describe the conv state shape. + LFM2 only uses the conv state (no SSM temporal state), so we set + state_size=0 which makes temporal state minimal. 
+ """ + from sglang.srt.layers.dp_attention import get_attention_tp_size + + conv_layer_ids = [i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")] + if not conv_layer_ids: + return None + + hidden_size = self.hidden_size + # conv_L_cache in config is kernel_size (e.g., 3) + conv_kernel = int(self.conv_L_cache) + L_cache = conv_kernel - 1 # actual cache size + tp_size = get_attention_tp_size() + + # Create Mamba2StateShape compatible with SGLang's infrastructure + # For LFM2 conv layers: + # - Conv state shape: (hidden_size, L_cache) per layer + # - No SSM temporal state (state_size=0) + # + # Mamba2StateShape.create() computes: + # - conv_dim = intermediate_size // tp_size (for TP sharding) + # - We want conv_dim = hidden_size, so intermediate_size = hidden_size * tp_size + shape = Mamba2StateShape.create( + tp_world_size=tp_size, + intermediate_size=hidden_size * tp_size, # Results in conv_dim = hidden_size + n_groups=1, + num_heads=1, + head_dim=1, + state_size=0, # No SSM state, only conv state + conv_kernel=conv_kernel, + ) + + return Mamba2CacheParams(shape=shape, layers=conv_layer_ids) + + # Patch the CLASS, not instances + Lfm2Config.full_attention_layer_ids = property(_get_full_attention_layer_ids) + Lfm2Config.linear_layer_ids = property(_get_linear_layer_ids) + Lfm2Config.mamba2_cache_params = property(_get_mamba2_cache_params) + Lfm2Config.mamba_chunk_size = property(_get_mamba_chunk_size) + Lfm2Config._sglang_patched = True + + logger.info("Patched Lfm2Config class with SGLang hybrid cache properties") + + +# Patch at module import time - this runs when lfm2.py is imported +_patch_lfm2_config_class() + + +# ============================================================================ +# Model Components +# ============================================================================ + +class Lfm2RMSNorm(nn.Module): + """ + LFM2-specific RMSNorm that uses weight * x (NOT (1 + weight) * x like Gemma). 
+ This matches the HuggingFace Lfm2RMSNorm implementation exactly. + """ + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + + +class Lfm2MLP(nn.Module): + """MLP with SwiGLU activation - uses w1/w2/w3 naming to match checkpoint.""" + def __init__( + self, + config: Lfm2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + intermediate_size = config.intermediate_size + if config.block_auto_adjust_ff_dim: + intermediate_size = int(2 * intermediate_size / 3) + if config.block_ffn_dim_multiplier is not None: + intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size) + intermediate_size = config.block_multiple_of * ( + (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of + ) + + self.w1 = ColumnParallelLinear( + input_size=config.hidden_size, + output_size=intermediate_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w1", prefix), + ) + self.w3 = ColumnParallelLinear( + input_size=config.hidden_size, + output_size=intermediate_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w3", prefix), + ) + self.w2 = RowParallelLinear( + input_size=intermediate_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w2", prefix), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate, _ = self.w1(x) + up, _ = self.w3(x) + h = F.silu(gate) * up + out, _ = self.w2(h) + return out + + +class Lfm2Attention(nn.Module): + """Attention 
with RoPE and Q/K layernorm - matches checkpoint weight names.""" + def __init__( + self, + config: Lfm2Config, + layer_id: int, + attn_layer_id: int, # Sequential ID for attention layers only (for KV cache) + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.layer_id = layer_id + self.attn_layer_id = attn_layer_id + + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = config.num_key_value_heads + self.head_dim = getattr(config, "head_dim", None) or (self.hidden_size // self.total_num_heads) + self.scaling = self.head_dim**-0.5 + + rope_parameters = getattr(config, "rope_parameters", None) + if rope_parameters is not None and "rope_theta" in rope_parameters: + self.rope_theta = rope_parameters["rope_theta"] + else: + self.rope_theta = getattr(config, "rope_theta", 10000) + + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.rope_scaling = getattr(config, "rope_scaling", None) + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + rope_scaling=self.rope_scaling, + base=self.rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), + ) + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + + # Named out_proj to match checkpoint + self.out_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("out_proj", prefix), + ) + + # Named q_layernorm/k_layernorm to match checkpoint + self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + + self.num_local_q_heads = 
self.qkv_proj.num_heads + self.num_local_kv_heads = self.qkv_proj.num_kv_heads + + self.attn = RadixAttention( + num_heads=self.num_local_q_heads, + head_dim=self.head_dim, + scaling=self.scaling, + num_kv_heads=self.num_local_kv_heads, + layer_id=self.layer_id, # Use global layer ID for routing in hybrid backend + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + T = hidden_states.shape[0] + qkv, _ = self.qkv_proj(hidden_states) + + q_size = self.num_local_q_heads * self.head_dim + kv_size = self.num_local_kv_heads * self.head_dim + + q, k, v = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) + + q = q.reshape(T, self.num_local_q_heads, self.head_dim) + k = k.reshape(T, self.num_local_kv_heads, self.head_dim) + + q = self.q_layernorm(q.reshape(-1, self.head_dim)).reshape(T, self.num_local_q_heads, self.head_dim) + k = self.k_layernorm(k.reshape(-1, self.head_dim)).reshape(T, self.num_local_kv_heads, self.head_dim) + + q, k = self.rotary_emb(positions, q, k) + + q = q.reshape(T, -1) + k = k.reshape(T, -1) + + attn_out = self.attn(q, k, v, forward_batch) + + out, _ = self.out_proj(attn_out) + return out + + +class Lfm2ShortConv(nn.Module): + """ + Short conv implementation using SGLang's MambaPool for state management. + + This implementation: + 1. Uses nn.Linear for in_proj/out_proj (matching HF checkpoint) + 2. Accesses conv state through HybridReqToTokenPool.mamba2_layer_cache() + 3. Handles prefill and decode modes properly + 4. 
Is CUDA graph compatible (uses index_copy_ instead of .item()) + """ + + def __init__( + self, + config: Lfm2Config, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + # conv_L_cache in config is the kernel size (e.g., 3), NOT kernel-1 + # The "cache" stores kernel_size - 1 values for causal conv + self.conv_kernel = int(config.conv_L_cache) # kernel_size from config + self.L_cache = self.conv_kernel - 1 # actual cache size = kernel - 1 + self.bias = bool(config.conv_bias) + self.hidden_size = config.hidden_size + + # Match HF exactly - use nn.Linear (not parallel versions) + self.in_proj = nn.Linear( + config.hidden_size, + 3 * config.hidden_size, + bias=self.bias, + ) + self.out_proj = nn.Linear( + config.hidden_size, + config.hidden_size, + bias=self.bias, + ) + + # Depthwise conv1d with causal padding + self.conv = nn.Conv1d( + in_channels=config.hidden_size, + out_channels=config.hidden_size, + kernel_size=self.conv_kernel, + groups=config.hidden_size, + bias=self.bias, + padding=self.L_cache, # Causal padding = kernel_size - 1 + ) + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + """ + Forward pass using SGLang's hybrid caching infrastructure. 
+ + Conv state is accessed through: + forward_batch.req_to_token_pool.mamba2_layer_cache(self.layer_idx) + """ + forward_mode = forward_batch.forward_mode + + if forward_mode.is_idle(): + return hidden_states + + # Get conv cache through HybridReqToTokenPool + # mamba2_layer_cache returns a cache object with .conv and .temporal attributes + layer_cache = forward_batch.req_to_token_pool.mamba2_layer_cache(self.layer_idx) + # conv is a list of tensors, one per conv state component + # For Mamba2, conv[0] has shape [pool_size+1, conv_dim, conv_kernel-1] + # For LFM2, this is [pool_size+1, hidden_size, L_cache] + conv_state = layer_cache.conv[0] + + # Get request pool indices for current batch + req_pool_indices = forward_batch.req_pool_indices + + if forward_mode.is_decode(): + return self._forward_decode( + hidden_states, + conv_state, + req_pool_indices, + ) + else: + # Prefill/extend mode + seq_lens = getattr(forward_batch, 'extend_seq_lens', None) + + if seq_lens is not None and len(seq_lens) > 1: + return self._forward_prefill_multi( + hidden_states, + conv_state, + req_pool_indices, + seq_lens, + ) + else: + return self._forward_prefill_single( + hidden_states, + conv_state, + req_pool_indices, + ) + + def _forward_prefill_single( + self, + hidden_states: torch.Tensor, + conv_state: torch.Tensor, + req_pool_indices: torch.Tensor, + ) -> torch.Tensor: + """Prefill for a single sequence - matches HF slow_forward exactly.""" + T = hidden_states.shape[0] + + # Step 1: in_proj + proj = self.in_proj(hidden_states) # [T, 3H] + + # Step 2: transpose to [1, 3H, T] for conv + proj_t = proj.transpose(0, 1).unsqueeze(0) # [1, 3H, T] + + # Step 3: chunk into B, C, x + B_gate, C_gate, x = proj_t.chunk(3, dim=1) # each [1, H, T] + + # Step 4: Bx = B * x + Bx = B_gate * x # [1, H, T] + + # Step 5: conv with causal padding (output is truncated to T) + conv_out = self.conv(Bx)[..., :T] # [1, H, T] + + # Step 6: y = C * conv_out + y = C_gate * conv_out # [1, H, T] + + # Step 7: 
transpose back + y = y.squeeze(0).transpose(0, 1) # [T, H] + + # Step 8: out_proj + y = self.out_proj(y) # [T, H] + + # Store the final conv state (last L_cache values of Bx) + if T >= self.L_cache: + final_state = Bx[0, :, -self.L_cache:] # [H, L_cache] + else: + final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) # [H, L_cache] + + # Store for the request using index_copy_ (CUDA graph compatible) + # Ensure dtype matches conv_state (may be bfloat16 while final_state is float32) + if req_pool_indices.numel() > 0: + conv_state.index_copy_(0, req_pool_indices[:1].long(), final_state.unsqueeze(0).to(conv_state.dtype)) + + return y + + def _forward_prefill_multi( + self, + hidden_states: torch.Tensor, + conv_state: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + ) -> torch.Tensor: + """Process multiple sequences separately to avoid cross-contamination.""" + outputs = [] + start_idx = 0 + + seq_lens_list = seq_lens.tolist() if isinstance(seq_lens, torch.Tensor) else list(seq_lens) + req_pool_indices_long = req_pool_indices.long() + + for i, seq_len in enumerate(seq_lens_list): + seq_len = int(seq_len) + end_idx = start_idx + seq_len + + seq_hidden = hidden_states[start_idx:end_idx] + T = seq_hidden.shape[0] + + # Process this sequence + proj = self.in_proj(seq_hidden) + proj_t = proj.transpose(0, 1).unsqueeze(0) + B_gate, C_gate, x = proj_t.chunk(3, dim=1) + Bx = B_gate * x + conv_out = self.conv(Bx)[..., :T] + y = C_gate * conv_out + y = y.squeeze(0).transpose(0, 1) + y = self.out_proj(y) + + outputs.append(y) + + # Store conv state for this sequence + if T >= self.L_cache: + final_state = Bx[0, :, -self.L_cache:] + else: + final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) + + # Use index_copy_ for CUDA graph compatibility + # Ensure dtype matches conv_state (may be bfloat16 while final_state is float32) + conv_state.index_copy_(0, req_pool_indices_long[i:i+1], final_state.unsqueeze(0).to(conv_state.dtype)) + + start_idx 
= end_idx + + return torch.cat(outputs, dim=0) + + def _forward_decode( + self, + hidden_states: torch.Tensor, + conv_state: torch.Tensor, + req_pool_indices: torch.Tensor, + ) -> torch.Tensor: + """ + Decode mode: single token per sequence using cached state. + CUDA graph compatible - uses only tensor operations, no .item() calls. + """ + batch_size = hidden_states.shape[0] + + req_pool_indices_long = req_pool_indices.long() + + # in_proj for all tokens + proj = self.in_proj(hidden_states) # [B, 3H] + + # Split into gates + B_gate, C_gate, x = proj.chunk(3, dim=-1) # each [B, H] + + # Compute Bx + Bx = B_gate * x # [B, H] + + # Get conv weights - shape is [H, 1, kernel_size], we need [H, kernel_size] + conv_weights = self.conv.weight[:, 0, :] # [H, kernel_size] + + # Gather current states: [B, H, L_cache] + current_states = conv_state[req_pool_indices_long] + + # Roll states left by 1 and insert new Bx at the end + new_states = torch.cat([ + current_states[:, :, 1:], # [B, H, L_cache-1] + Bx.unsqueeze(-1) # [B, H, 1] + ], dim=-1) # [B, H, L_cache] + + # Scatter updated states back using index_copy_ (CUDA graph compatible) + # Ensure dtype matches conv_state (may be bfloat16 while new_states is float32) + conv_state.index_copy_(0, req_pool_indices_long, new_states.to(conv_state.dtype)) + + # Compute conv output: need full kernel_size inputs + # Prepend zeros to get [B, H, kernel_size] for the conv operation + # Actually, we need to apply the conv kernel to the last kernel_size values + # The state has L_cache = kernel_size - 1 values, plus the new Bx makes kernel_size + conv_input = torch.cat([ + current_states[:, :, -(self.conv_kernel - 1):], # [B, H, kernel_size-1] + Bx.unsqueeze(-1) # [B, H, 1] + ], dim=-1) # [B, H, kernel_size] + + # Apply conv weights: element-wise multiply and sum + conv_out = (conv_input * conv_weights.unsqueeze(0)).sum(dim=-1) # [B, H] + + # Add bias if present + if self.bias and self.conv.bias is not None: + conv_out = conv_out + 
self.conv.bias + + # Apply output gate + y = C_gate * conv_out # [B, H] + + # Apply out_proj (ensure dtype matches model weights) + y = self.out_proj(y.to(hidden_states.dtype)) # [B, H] + + return y + + +class Lfm2DecoderLayer(nn.Module): + """Decoder layer - can be attention or conv based on config.""" + def __init__( + self, + config: Lfm2Config, + layer_id: int, + attn_layer_id: int, # Sequential ID for attention layers (for KV cache) + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.layer_id = layer_id + self.layer_type = config.layer_types[layer_id] + self.is_attention_layer = self.layer_type == "full_attention" + + self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + + if self.is_attention_layer: + self.self_attn = Lfm2Attention( + config=config, + layer_id=layer_id, + attn_layer_id=attn_layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + elif self.layer_type in ("conv", "short_conv"): + # Named 'conv' to match checkpoint (model.layers.X.conv.*) + self.conv = Lfm2ShortConv( + config=config, + layer_idx=layer_id, + quant_config=quant_config, + prefix=add_prefix("conv", prefix), + ) + else: + raise ValueError(f"Unknown layer type: {self.layer_type}") + + self.feed_forward = Lfm2MLP( + config=config, + quant_config=quant_config, + prefix=add_prefix("feed_forward", prefix), + ) + + def forward( + self, + layer_id: int, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward matching HF exactly.""" + if not forward_batch.forward_mode.is_idle(): + residual = hidden_states + normed = self.operator_norm(hidden_states) + + if self.is_attention_layer: + hidden_states = self.self_attn( + positions=positions, + hidden_states=normed, 
+ forward_batch=forward_batch + ) + else: + hidden_states = self.conv( + hidden_states=normed, + forward_batch=forward_batch, + ) + + hidden_states = hidden_states + residual + hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states)) + + return hidden_states, residual + + +class Lfm2Model(nn.Module): + def __init__( + self, + config: Lfm2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + prefix=add_prefix("embed_tokens", prefix), + ) + + # Compute attention layer IDs (sequential numbering for KV cache) + attn_layer_ids = [] + attn_count = 0 + for layer_type in config.layer_types: + if layer_type == "full_attention": + attn_layer_ids.append(attn_count) + attn_count += 1 + else: + attn_layer_ids.append(-1) + + self.num_attention_layers = attn_count + + logger.info(f"LFM2 model has {attn_count} attention layers and " + f"{len(config.layer_types) - attn_count} conv layers " + f"out of {config.num_hidden_layers} total") + + def get_layer(idx: int, prefix: str, **kwargs): + return Lfm2DecoderLayer( + config=config, + layer_id=idx, + attn_layer_id=attn_layer_ids[idx], + quant_config=quant_config, + prefix=prefix, + ) + + self.layers = make_layers(config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + + # Named embedding_norm to match checkpoint + self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + layer_id=i, + 
positions=positions, + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + hidden_states = self.embedding_norm(hidden_states) + return hidden_states + + +class Lfm2ForCausalLM(nn.Module): + """ + LFM2 for Causal Language Modeling. + + This model has a hybrid architecture with both attention and conv layers. + - Attention layers use standard KV cache (managed by SGLang) + - Conv layers use MambaPool for state caching (via HybridReqToTokenPool) + + IMPORTANT: For this to work, you must also modify model_runner.py to add + Lfm2Config to the mamba2_config property's isinstance check: + + from transformers import Lfm2Config + ... + if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config): + return config + """ + fall_back_to_pt_during_load = False + + def __init__( + self, + config: Lfm2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.pp_group = get_pp_group() + assert self.pp_group.is_first_rank and self.pp_group.is_last_rank + + self.quant_config = quant_config + self.model = Lfm2Model(config, quant_config, prefix=add_prefix("model", prefix)) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + org_num_embeddings=config.vocab_size, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) + + # Store number of attention layers for KV cache sizing + self.num_attention_layers = self.model.num_attention_layers + + def get_num_kv_cache_layers(self) -> int: + """Return the number of layers that need KV cache (attention layers only).""" + return self.num_attention_layers + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, + positions, + forward_batch, + 
inputs_embeds, + ) + return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False) -> Set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + logger.info("=== Model parameter names (first 30) ===") + for i, (name, param) in enumerate(params_dict.items()): + if i >= 30: + break + logger.info(f" {name}: {param.shape}") + + loaded_params: Set[str] = set() + conv_weights_loaded = 0 + missing_params = [] + + embed_tokens_weight = None + + for name, loaded_weight in weights: + original_name = name + + if "rotary_emb.inv_freq" in name: + continue + + if "embed_tokens.weight" in name: + embed_tokens_weight = loaded_weight + + if conv_weights_loaded < 5 and ".conv." in name: + logger.info(f"Loading: {name}, shape: {loaded_weight.shape}") + + # Handle QKV stacking + did_stack = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + did_stack = True + break + if name not in params_dict: + did_stack = True + break + param = params_dict[name] + weight_loader = getattr(param, "weight_loader") + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + did_stack = True + break + if did_stack: + continue + + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + if len(missing_params) < 20: + missing_params.append(f"{original_name} -> {name}") + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if ".conv." 
in name: + conv_weights_loaded += 1 + + # Handle tied lm_head weight + if "lm_head.weight" not in loaded_params and "lm_head.weight" in params_dict: + if embed_tokens_weight is not None: + logger.info("Tying lm_head.weight to embed_tokens.weight") + param = params_dict["lm_head.weight"] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, embed_tokens_weight) + loaded_params.add("lm_head.weight") + else: + logger.warning("lm_head.weight not found and no embed_tokens.weight to tie") + + if missing_params: + logger.warning(f"Missing params (first 20): {missing_params}") + + logger.info(f"Loaded {conv_weights_loaded} conv weight tensors") + logger.info(f"Total loaded params: {len(loaded_params)}") + + unloaded = set(params_dict.keys()) - loaded_params + if unloaded: + logger.warning(f"Unloaded params ({len(unloaded)}): {list(unloaded)[:10]}...") + + return loaded_params + + +EntryClass = [Lfm2ForCausalLM] From 3ae717b93304790b4fd08d28a695745843023d2f Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Wed, 7 Jan 2026 12:14:25 +0000 Subject: [PATCH 02/14] Add LFM2 to test_generation_models --- test/registered/models/test_generation_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/registered/models/test_generation_models.py b/test/registered/models/test_generation_models.py index bc45d1918785..bdc2ce9b07b0 100644 --- a/test/registered/models/test_generation_models.py +++ b/test/registered/models/test_generation_models.py @@ -106,6 +106,10 @@ class ModelCase: trust_remote_code=True, skip_long_prompt=True, ), + ModelCase( + "LiquidAI/LFM2.5-1.2B-Instruct", + trust_remote_code=True, + ), ] TORCH_DTYPES = [torch.float16] From ac91a836a8da0408700fe3dc6340672f61a55b2e Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Wed, 7 Jan 2026 14:49:22 +0000 Subject: [PATCH 03/14] Add function calling --- .../srt/function_call/function_call_parser.py | 2 + .../sglang/srt/function_call/lfm2_detector.py | 414 
++++++++++++++++++ .../test_function_call_parser.py | 356 +++++++++++++++ 3 files changed, 772 insertions(+) create mode 100644 python/sglang/srt/function_call/lfm2_detector.py diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 8df0a2401742..70b7f1ec3c02 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -19,6 +19,7 @@ from sglang.srt.function_call.gpt_oss_detector import GptOssDetector from sglang.srt.function_call.internlm_detector import InternlmDetector from sglang.srt.function_call.kimik2_detector import KimiK2Detector +from sglang.srt.function_call.lfm2_detector import Lfm2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mimo_detector import MiMoDetector from sglang.srt.function_call.minimax_m2 import MinimaxM2Detector @@ -50,6 +51,7 @@ class FunctionCallParser: "glm47": Glm47MoeDetector, "gpt-oss": GptOssDetector, "kimi_k2": KimiK2Detector, + "lfm2": Lfm2Detector, "llama3": Llama32Detector, "mimo": MiMoDetector, "mistral": MistralDetector, diff --git a/python/sglang/srt/function_call/lfm2_detector.py b/python/sglang/srt/function_call/lfm2_detector.py new file mode 100644 index 000000000000..fde95c4e41f0 --- /dev/null +++ b/python/sglang/srt/function_call/lfm2_detector.py @@ -0,0 +1,414 @@ +""" +Detector for LFM2 (Liquid Foundation Model 2) function call format. 
+ +Format Structure (Pythonic style): +``` +<|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|> +``` + +Multiple tool calls: +``` +<|tool_call_start|>[func1(arg="val"), func2(arg="val")]<|tool_call_end|> +``` + +Also supports JSON format: +``` +<|tool_call_start|>[{"name": "func_name", "arguments": {...}}]<|tool_call_end|> +``` +""" + +import ast +import json +import logging +import re +from typing import Any, Dict, List, Optional, Tuple + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.environ import envs +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + ToolCallItem, + _GetInfoFunc, +) + +logger = logging.getLogger(__name__) + + +class Lfm2Detector(BaseFormatDetector): + """ + Detector for LFM2 (Liquid Foundation Model 2) function call format. + + Supports both Pythonic and JSON formats: + + Pythonic: + ``` + <|tool_call_start|>[calculator(expression="5 * 7")]<|tool_call_end|> + ``` + + JSON: + ``` + <|tool_call_start|>[{"name": "calculator", "arguments": {"expression": "5 * 7"}}]<|tool_call_end|> + ``` + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "<|tool_call_start|>" + self.eot_token = "<|tool_call_end|>" + self.tool_call_separator = "" + + def has_tool_call(self, text: str) -> bool: + """Check if the text contains an LFM2 format tool call.""" + return self.bot_token in text + + def _get_parameter_value(self, val: ast.AST) -> Any: + """ + Extract Python literal value from AST node. + + Handles constants, dicts, and lists recursively. + Reuses pattern from PythonicDetector. 
+ """ + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + return { + self._get_parameter_value(k): self._get_parameter_value(v) + for k, v in zip(val.keys, val.values) + if k is not None # Handle {**kwargs} case where key is None + } + elif isinstance(val, ast.List): + return [self._get_parameter_value(v) for v in val.elts] + elif isinstance(val, ast.Tuple): + return tuple(self._get_parameter_value(v) for v in val.elts) + elif isinstance(val, ast.Name): + # Handle True, False, None as names in older Python + if val.id == "True": + return True + elif val.id == "False": + return False + elif val.id == "None": + return None + else: + raise ValueError(f"Unsupported name reference: {val.id}") + elif isinstance(val, ast.UnaryOp) and isinstance(val.op, ast.USub): + # Handle negative numbers like -5 + inner = self._get_parameter_value(val.operand) + if isinstance(inner, (int, float)): + return -inner + raise ValueError(f"Cannot negate non-numeric value: {inner}") + else: + raise ValueError(f"Tool call arguments must be literals, got: {type(val).__name__}") + + def _parse_pythonic_call(self, call: ast.Call, call_index: int, tool_indices: Dict[str, int]) -> Optional[ToolCallItem]: + """ + Parse a single AST Call node into a ToolCallItem. 
+ + Args: + call: AST Call node representing a function call + call_index: Index of this call in the list of calls + tool_indices: Mapping of tool names to their indices + + Returns: + ToolCallItem if successful, None if the call should be skipped + """ + if not isinstance(call.func, ast.Name): + logger.warning(f"Tool call function must be a simple name, got: {type(call.func).__name__}") + return None + + function_name = call.func.id + + # Validate that the function exists in the tools + if function_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {function_name}") + if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get(): + return None # Skip unknown tools (default legacy behavior) + + # Parse arguments + arguments = {} + for keyword in call.keywords: + if keyword.arg is None: + # **kwargs unpacking - skip for now + logger.warning("Tool call with **kwargs unpacking is not supported") + continue + try: + arguments[keyword.arg] = self._get_parameter_value(keyword.value) + except ValueError as e: + logger.warning(f"Failed to parse argument {keyword.arg}: {e}") + return None + + return ToolCallItem( + tool_index=call_index, # Use the call index in the response, not tool position + name=function_name, + parameters=json.dumps(arguments, ensure_ascii=False), + ) + + def _parse_pythonic_content(self, content: str, tools: List[Tool]) -> Tuple[List[ToolCallItem], str]: + """ + Parse Pythonic format tool calls using AST. 
+ + Args: + content: The content between tool call tags (without the tags) + tools: List of available tools + + Returns: + Tuple of (list of parsed calls, error message if any) + """ + content = content.strip() + tool_indices = self._get_tool_indices(tools) + + try: + module = ast.parse(content) + parsed = getattr(module.body[0], "value", None) if module.body else None + + if parsed is None: + return [], "Empty or invalid Python expression" + + # Handle both single call and list of calls + if isinstance(parsed, ast.List): + call_nodes = parsed.elts + elif isinstance(parsed, ast.Call): + call_nodes = [parsed] + else: + return [], f"Expected function call or list, got: {type(parsed).__name__}" + + # Validate all elements are calls + if not all(isinstance(e, ast.Call) for e in call_nodes): + return [], "Not all elements in list are function calls" + + calls = [] + for call_index, call in enumerate(call_nodes): + item = self._parse_pythonic_call(call, call_index, tool_indices) + if item is not None: + calls.append(item) + + return calls, "" + + except SyntaxError as e: + return [], f"Python syntax error: {e}" + except Exception as e: + logger.exception("Unexpected error in pythonic tool call parsing") + return [], f"Unexpected error: {e}" + + def _parse_json_content(self, content: str, tools: List[Tool]) -> Tuple[List[ToolCallItem], str]: + """ + Parse JSON format tool calls. + + Uses parse_base_json from BaseFormatDetector for consistent handling + of SGLANG_FORWARD_UNKNOWN_TOOLS and tool validation. 
+ + Args: + content: The content between tool call tags (without the tags) + tools: List of available tools + + Returns: + Tuple of (list of parsed calls, error message if any) + """ + content = content.strip() + + try: + parsed = json.loads(content) + # parse_base_json handles list/dict normalization, tool validation, + # and SGLANG_FORWARD_UNKNOWN_TOOLS consistently with other detectors + calls = self.parse_base_json(parsed, tools) + return calls, "" + + except json.JSONDecodeError as e: + return [], f"JSON parse error: {e}" + + def _parse_tool_calls_content(self, content: str, tools: List[Tool]) -> List[ToolCallItem]: + """ + Parse the content between tool call tags. + Handles both JSON and Pythonic formats. + """ + content = content.strip() + + # First, try JSON format (faster check) + if content.startswith("[{") or content.startswith("{"): + calls, error = self._parse_json_content(content, tools) + if calls: + return calls + # If JSON parsing failed but it looked like JSON, log the error + if error: + logger.debug(f"JSON parsing failed: {error}, trying Pythonic format") + + # Try Pythonic format + calls, error = self._parse_pythonic_content(content, tools) + if calls: + return calls + + if error: + logger.warning(f"Failed to parse tool calls: {error}") + + return [] + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + """ + One-time parsing: Detects and parses tool calls in the provided text. 
+ """ + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + + if self.bot_token not in text: + return StreamingParseResult(normal_text=normal_text, calls=[]) + + # Find all <|tool_call_start|>...<|tool_call_end|> blocks + pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}" + match_result_list = re.findall(pattern, text, re.DOTALL) + + calls = [] + for match_result in match_result_list: + parsed_calls = self._parse_tool_calls_content(match_result, tools) + calls.extend(parsed_calls) + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + def _strip_special_tokens(self, text: str) -> str: + """Remove special tokens from text.""" + return text.replace(self.bot_token, "").replace(self.eot_token, "") + + def _find_matching_bracket(self, buffer: str, start: int) -> int: + """ + Find the matching closing bracket for the opening bracket at start position. + Properly handles nested brackets and strings. + + Args: + buffer: The text buffer to search in + start: Position of the opening bracket '[' + + Returns: + Position of the matching closing bracket ']', or -1 if not found + """ + bracket_count = 0 + in_string = False + string_char = None + escape_next = False + + for i in range(start, len(buffer)): + char = buffer[i] + + if escape_next: + escape_next = False + continue + + if char == "\\": + escape_next = True + continue + + if char in ('"', "'") and not in_string: + in_string = True + string_char = char + elif char == string_char and in_string: + in_string = False + string_char = None + elif not in_string: + if char == "[": + bracket_count += 1 + elif char == "]": + bracket_count -= 1 + if bracket_count == 0: + return i + + return -1 # No matching bracket found + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + """ + Streaming incremental parsing for LFM2 tool calls. + + This implementation properly handles Pythonic format by: + 1. 
Buffering until we see complete <|tool_call_start|>[...]<|tool_call_end|> + 2. Emitting normal text before tool calls immediately + 3. Parsing complete tool call blocks using detect_and_parse + + Based on PythonicDetector streaming logic. + """ + self._buffer += new_text + + # Check for partial bot_token at the end + partial_bot = self._ends_with_partial_token(self._buffer, self.bot_token) + partial_eot = self._ends_with_partial_token(self._buffer, self.eot_token) + + # Find bot_token position + bot_pos = self._buffer.find(self.bot_token) + + if bot_pos == -1: + # No tool call start found + if partial_bot: + # Might be partial bot_token, hold back that part + safe_text = self._buffer[:-partial_bot] + self._buffer = self._buffer[-partial_bot:] + return StreamingParseResult(normal_text=safe_text) + else: + # No tool call, emit all as normal text + normal_text = self._strip_special_tokens(self._buffer) + self._buffer = "" + return StreamingParseResult(normal_text=normal_text) + + # We have bot_token - extract any normal text before it + normal_text_before = self._buffer[:bot_pos] if bot_pos > 0 else "" + + # Look for the end token + eot_pos = self._buffer.find(self.eot_token, bot_pos + len(self.bot_token)) + + if eot_pos == -1: + # No end token yet - check if we might have a partial one + if partial_eot: + # Hold back the partial token, but we need to keep buffering + # Just emit any normal text before the tool call + if normal_text_before: + self._buffer = self._buffer[bot_pos:] + return StreamingParseResult(normal_text=normal_text_before) + # Keep buffering + return StreamingParseResult(normal_text="") + + # No end token and no partial - keep buffering but emit normal text + if normal_text_before: + self._buffer = self._buffer[bot_pos:] + return StreamingParseResult(normal_text=normal_text_before) + + # Just keep buffering + return StreamingParseResult(normal_text="") + + # We have a complete tool call block + tool_call_block = self._buffer[bot_pos:eot_pos + 
len(self.eot_token)] + remaining = self._buffer[eot_pos + len(self.eot_token):] + + # Parse the complete block + result = self.detect_and_parse(tool_call_block, tools) + + # Update buffer with remaining text + self._buffer = remaining + + # Add any normal text before the tool call + if normal_text_before: + result.normal_text = normal_text_before + (result.normal_text or "") + + return result + + def supports_structural_tag(self) -> bool: + """ + Return False because LFM2 uses Pythonic format which is not JSON-compatible. + + structural_tag only supports JSON-compatible content between begin and end, + so it cannot parse Pythonic function call syntax like `func(arg="val")`. + """ + return False + + def structure_info(self) -> _GetInfoFunc: + """ + Return structure info for constrained generation. + + Note: This is provided for completeness but won't be used since + supports_structural_tag() returns False. + """ + return lambda name: StructureInfo( + begin='<|tool_call_start|>[' + name + '(', + end=")]<|tool_call_end|>", + trigger="<|tool_call_start|>", + ) diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index e2bb53f8b33a..b7e1fc99f23b 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -5,6 +5,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.lfm2_detector import Lfm2Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -2998,5 +2999,360 @@ def test_three_tool_calls_separate_chunks_with_commas(self): 
self.assertEqual(total_calls, 3, "Should have parsed exactly 3 tool calls") +class TestLfm2Detector(unittest.TestCase): + """Tests for LFM2 (Liquid Foundation Model 2) function call detector.""" + + def setUp(self): + """Set up test tools and detector.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + }, + "unit": { + "type": "string", + "description": "Temperature unit", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search for information", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="calculator", + description="Perform calculations", + parameters={ + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression", + }, + }, + "required": ["expression"], + }, + ), + ), + ] + self.detector = Lfm2Detector() + + # ==================== has_tool_call tests ==================== + + def test_has_tool_call_true(self): + """Test detection of tool call markers.""" + text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>' + self.assertTrue(self.detector.has_tool_call(text)) + + def test_has_tool_call_false(self): + """Test no false positives for regular text.""" + text = "The weather in Paris is nice today." 
+ self.assertFalse(self.detector.has_tool_call(text)) + + def test_has_tool_call_partial_marker(self): + """Test that partial markers are detected (start token present).""" + text = '<|tool_call_start|>[get_weather(city="Paris")' + self.assertTrue(self.detector.has_tool_call(text)) + + # ==================== detect_and_parse tests (Pythonic format) ==================== + + def test_detect_and_parse_pythonic_simple(self): + """Test parsing a simple Pythonic format tool call.""" + text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[0].tool_index, 0) + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Paris") + + def test_detect_and_parse_pythonic_multiple_args(self): + """Test parsing with multiple arguments.""" + text = '<|tool_call_start|>[get_weather(city="London", unit="celsius")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "London") + self.assertEqual(params["unit"], "celsius") + + def test_detect_and_parse_pythonic_no_args(self): + """Test parsing function with no arguments.""" + # Add a no-arg tool for this test + tools_with_noarg = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + text = "<|tool_call_start|>[get_time()]<|tool_call_end|>" + result = self.detector.detect_and_parse(text, tools_with_noarg) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_time") + + def test_detect_and_parse_pythonic_multiple_calls(self): + """Test 
parsing multiple tool calls in one block.""" + text = '<|tool_call_start|>[get_weather(city="Paris"), search(query="restaurants")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + params1 = json.loads(result.calls[0].parameters) + params2 = json.loads(result.calls[1].parameters) + self.assertEqual(params1["city"], "Paris") + self.assertEqual(params2["query"], "restaurants") + + def test_detect_and_parse_with_normal_text_before(self): + """Test parsing with normal text before the tool call.""" + text = 'Let me check the weather for you. <|tool_call_start|>[get_weather(city="Tokyo")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "Let me check the weather for you.") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_detect_and_parse_special_characters_in_value(self): + """Test parsing with special characters in argument values.""" + text = '<|tool_call_start|>[search(query="what\'s the weather?")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertIn("weather", params["query"]) + + def test_detect_and_parse_numeric_values(self): + """Test parsing with numeric argument values.""" + text = '<|tool_call_start|>[calculator(expression="5 * 7")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "calculator") + + # ==================== detect_and_parse tests (JSON format) ==================== + + def test_detect_and_parse_json_simple(self): + """Test parsing JSON format tool call.""" + text = '<|tool_call_start|>[{"name": 
"get_weather", "arguments": {"city": "Berlin"}}]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Berlin") + + def test_detect_and_parse_json_multiple_calls(self): + """Test parsing multiple JSON format tool calls.""" + text = '<|tool_call_start|>[{"name": "get_weather", "arguments": {"city": "Paris"}}, {"name": "search", "arguments": {"query": "hotels"}}]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + def test_detect_and_parse_json_with_parameters_key(self): + """Test parsing JSON format with 'parameters' key instead of 'arguments'.""" + text = '<|tool_call_start|>[{"name": "get_weather", "parameters": {"city": "Madrid"}}]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["city"], "Madrid") + + # ==================== Edge cases ==================== + + def test_detect_and_parse_no_tool_call(self): + """Test parsing text with no tool calls.""" + text = "This is just regular text without any tool calls." 
+ result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(result.calls, []) + + def test_detect_and_parse_unknown_function(self): + """Test parsing with unknown function name - skipped by default (SGLANG_FORWARD_UNKNOWN_TOOLS=false).""" + text = '<|tool_call_start|>[unknown_function(arg="value")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + # By default, unknown functions are skipped (consistent with other detectors) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_empty_content(self): + """Test parsing with empty content between markers.""" + text = "<|tool_call_start|><|tool_call_end|>" + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.calls, []) + + def test_detect_and_parse_multiple_blocks(self): + """Test parsing multiple separate tool call blocks.""" + text = '<|tool_call_start|>[get_weather(city="Paris")]<|tool_call_end|> Some text <|tool_call_start|>[search(query="food")]<|tool_call_end|>' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + # ==================== Streaming tests ==================== + # The LFM2 detector buffers until it sees complete <|tool_call_start|>...<|tool_call_end|> + # blocks, then parses the complete block. This allows proper handling of both + # JSON and Pythonic formats. 
+ + def test_streaming_json_complete_in_one_chunk(self): + """Test streaming with complete JSON tool call in one chunk.""" + text = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Rome"}}<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_streaming_json_split_across_chunks(self): + """Test streaming with JSON tool call split across multiple chunks - waits for complete block.""" + # Reset detector state + self.detector = Lfm2Detector() + + # First chunk: start marker and partial JSON (no end token) + chunk1 = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": ' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + + # Should buffer and not emit calls yet (waiting for complete block) + self.assertEqual(len(result1.calls), 0) + self.assertEqual(result1.normal_text, "") + + # Second chunk: complete the JSON and end token + chunk2 = '"Vienna"}}<|tool_call_end|>' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + + # Now should have the complete tool call + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + + def test_streaming_json_normal_text_before_tool_call(self): + """Test streaming with normal text before JSON tool call.""" + # Reset detector state + self.detector = Lfm2Detector() + + chunk1 = "I'll check the weather. 
" + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + + # Normal text should be returned + self.assertIn("check the weather", result1.normal_text) + + chunk2 = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Amsterdam"}}<|tool_call_end|>' + result2 = self.detector.parse_streaming_increment(chunk2, self.tools) + + self.assertEqual(len(result2.calls), 1) + + def test_streaming_eot_token_filtering(self): + """Test that end-of-turn token is filtered from normal text.""" + # Reset detector state + self.detector = Lfm2Detector() + + # Send text that ends with tool call end token (JSON format) + text = '<|tool_call_start|>{"name": "get_weather", "arguments": {"city": "Oslo"}}<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + # The normal_text should not contain the eot_token + self.assertNotIn("<|tool_call_end|>", result.normal_text) + + # ==================== Pythonic streaming tests ==================== + + def test_streaming_pythonic_complete_in_one_chunk(self): + """Test streaming with complete Pythonic tool call in one chunk.""" + self.detector = Lfm2Detector() + text = '<|tool_call_start|>[get_weather(city="Berlin")]<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(json.loads(result.calls[0].parameters), {"city": "Berlin"}) + + def test_streaming_pythonic_split_across_chunks(self): + """Test streaming with Pythonic tool call split across multiple chunks.""" + self.detector = Lfm2Detector() + + # First chunk: start marker and partial call + chunk1 = '<|tool_call_start|>[get_weather(city="' + result1 = self.detector.parse_streaming_increment(chunk1, self.tools) + + # Should buffer and not emit calls yet + self.assertEqual(len(result1.calls), 0) + + # Second chunk: complete the call + chunk2 = 'Munich")]<|tool_call_end|>' + result2 = 
self.detector.parse_streaming_increment(chunk2, self.tools) + + # Now should have the complete tool call + self.assertEqual(len(result2.calls), 1) + self.assertEqual(result2.calls[0].name, "get_weather") + self.assertEqual(json.loads(result2.calls[0].parameters), {"city": "Munich"}) + + def test_streaming_pythonic_multiple_calls(self): + """Test streaming with multiple Pythonic tool calls.""" + self.detector = Lfm2Detector() + + text = '<|tool_call_start|>[get_weather(city="Paris"), search(query="hotels")]<|tool_call_end|>' + result = self.detector.parse_streaming_increment(text, self.tools) + + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "search") + + # ==================== structure_info tests ==================== + + def test_supports_structural_tag(self): + """Test that LFM2 does not support structural tags (Pythonic format).""" + # LFM2 uses Pythonic format which is not JSON-compatible, + # so structural_tag constrained generation cannot be used + self.assertFalse(self.detector.supports_structural_tag()) + + def test_structure_info(self): + """Test structure info for constrained generation.""" + info_func = self.detector.structure_info() + info = info_func("get_weather") + + self.assertEqual(info.begin, "<|tool_call_start|>[get_weather(") + self.assertEqual(info.end, ")]<|tool_call_end|>") + self.assertEqual(info.trigger, "<|tool_call_start|>") + + if __name__ == "__main__": unittest.main() From de774563709f0e9ed001993d691a8fce1a62f522 Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Wed, 7 Jan 2026 15:43:01 +0000 Subject: [PATCH 04/14] Clean up the lfm2 implementation --- python/sglang/srt/models/lfm2.py | 626 +++++++++---------------------- 1 file changed, 179 insertions(+), 447 deletions(-) diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py index 814a4bcb9373..f30e9b77a1aa 100644 --- a/python/sglang/srt/models/lfm2.py +++ 
b/python/sglang/srt/models/lfm2.py @@ -1,147 +1,46 @@ -# sglang/srt/models/lfm2.py -# LFM2 implementation for SGLang -# Based on HuggingFace's implementation -# -# This version uses SGLang's hybrid caching infrastructure (HybridReqToTokenPool + MambaPool) -# while keeping the original working model structure and weight names. -# -# IMPORTANT: This file patches Lfm2Config at MODULE IMPORT TIME to add the properties -# required by model_runner.py for hybrid cache detection. You must also add Lfm2Config -# to the isinstance check in model_runner.py's mamba2_config property. +""" +LFM2 (Liquid Foundation Model 2) implementation for SGLang. + +This is a hybrid architecture with both attention and short conv layers. +- Attention layers use standard KV cache (RadixAttention) +- Conv layers use MambaPool for state caching (via HybridReqToTokenPool) + +The model uses a gated 1D causal convolution (kernel=3) instead of attention +in some layers, providing linear memory complexity for those layers. 
+""" import logging -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch -from torch import nn import torch.nn.functional as F +from torch import nn -from transformers import Lfm2Config - +from sglang.srt.configs.lfm2 import Lfm2Config from sglang.srt.distributed import get_pp_group -from sglang.srt.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, make_layers -# Import for Mamba state management - use SGLang's actual classes -from sglang.srt.configs.mamba_utils import ( - Mamba2StateShape, - Mamba2CacheParams, -) - logger = logging.getLogger(__name__) -# Debug flag - set to True to enable debug logging -DEBUG_LFM2 = False - - -def debug_tensor(name: str, t: torch.Tensor): - if DEBUG_LFM2: - logger.info(f"DEBUG {name}: shape={t.shape}, dtype={t.dtype}, " - f"min={t.min().item():.6f}, max={t.max().item():.6f}, " - f"mean={t.mean().item():.6f}, std={t.std().item():.6f}") - - -# ============================================================================ -# Config Patching - MUST happen at module import time -# 
============================================================================ - -def _patch_lfm2_config_class(): - """ - Patch Lfm2Config CLASS (not instance) with properties required by model_runner.py. - - This must happen at module import time, BEFORE model_runner.py checks the config type. - - model_runner.py's mamba2_config property does: - if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config): - return config - - And then uses config.mamba2_cache_params to set up HybridReqToTokenPool. - """ - if getattr(Lfm2Config, '_sglang_patched', False): - return - - def _get_full_attention_layer_ids(self) -> List[int]: - """Return indices of attention layers for KV cache.""" - return [i for i, lt in enumerate(self.layer_types) if lt == "full_attention"] - - def _get_linear_layer_ids(self) -> List[int]: - """Return indices of conv layers for conv state cache.""" - return [i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")] - - def _get_mamba_chunk_size(self) -> int: - """Return chunk size for Mamba2 backend. LFM2 doesn't use chunking, return 1.""" - return 1 - - def _get_mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: - """ - Get cache params for HybridReqToTokenPool initialization. - - Uses SGLang's Mamba2StateShape to describe the conv state shape. - LFM2 only uses the conv state (no SSM temporal state), so we set - state_size=0 which makes temporal state minimal. 
- """ - from sglang.srt.layers.dp_attention import get_attention_tp_size - - conv_layer_ids = [i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")] - if not conv_layer_ids: - return None - - hidden_size = self.hidden_size - # conv_L_cache in config is kernel_size (e.g., 3) - conv_kernel = int(self.conv_L_cache) - L_cache = conv_kernel - 1 # actual cache size - tp_size = get_attention_tp_size() - - # Create Mamba2StateShape compatible with SGLang's infrastructure - # For LFM2 conv layers: - # - Conv state shape: (hidden_size, L_cache) per layer - # - No SSM temporal state (state_size=0) - # - # Mamba2StateShape.create() computes: - # - conv_dim = intermediate_size // tp_size (for TP sharding) - # - We want conv_dim = hidden_size, so intermediate_size = hidden_size * tp_size - shape = Mamba2StateShape.create( - tp_world_size=tp_size, - intermediate_size=hidden_size * tp_size, # Results in conv_dim = hidden_size - n_groups=1, - num_heads=1, - head_dim=1, - state_size=0, # No SSM state, only conv state - conv_kernel=conv_kernel, - ) - - return Mamba2CacheParams(shape=shape, layers=conv_layer_ids) - - # Patch the CLASS, not instances - Lfm2Config.full_attention_layer_ids = property(_get_full_attention_layer_ids) - Lfm2Config.linear_layer_ids = property(_get_linear_layer_ids) - Lfm2Config.mamba2_cache_params = property(_get_mamba2_cache_params) - Lfm2Config.mamba_chunk_size = property(_get_mamba_chunk_size) - Lfm2Config._sglang_patched = True - - logger.info("Patched Lfm2Config class with SGLang hybrid cache properties") - - -# Patch at module import time - this runs when lfm2.py is imported -_patch_lfm2_config_class() - - -# ============================================================================ -# Model Components -# ============================================================================ class Lfm2RMSNorm(nn.Module): - """ - LFM2-specific RMSNorm that uses weight * x (NOT (1 + weight) * x like Gemma). 
- This matches the HuggingFace Lfm2RMSNorm implementation exactly. - """ + """LFM2-specific RMSNorm: weight * x (not (1 + weight) * x like Gemma).""" + def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -156,7 +55,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Lfm2MLP(nn.Module): - """MLP with SwiGLU activation - uses w1/w2/w3 naming to match checkpoint.""" + """MLP with SwiGLU activation.""" + def __init__( self, config: Lfm2Config, @@ -165,31 +65,35 @@ def __init__( ): super().__init__() intermediate_size = config.intermediate_size + if config.block_auto_adjust_ff_dim: intermediate_size = int(2 * intermediate_size / 3) if config.block_ffn_dim_multiplier is not None: - intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size) + intermediate_size = int( + config.block_ffn_dim_multiplier * intermediate_size + ) intermediate_size = config.block_multiple_of * ( - (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of + (intermediate_size + config.block_multiple_of - 1) + // config.block_multiple_of ) self.w1 = ColumnParallelLinear( - input_size=config.hidden_size, - output_size=intermediate_size, + config.hidden_size, + intermediate_size, bias=False, quant_config=quant_config, prefix=add_prefix("w1", prefix), ) self.w3 = ColumnParallelLinear( - input_size=config.hidden_size, - output_size=intermediate_size, + config.hidden_size, + intermediate_size, bias=False, quant_config=quant_config, prefix=add_prefix("w3", prefix), ) self.w2 = RowParallelLinear( - input_size=intermediate_size, - output_size=config.hidden_size, + intermediate_size, + config.hidden_size, bias=False, quant_config=quant_config, prefix=add_prefix("w2", prefix), @@ -198,47 +102,42 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: gate, _ = self.w1(x) up, _ = self.w3(x) - h = F.silu(gate) * up - out, _ = self.w2(h) + out, _ = 
self.w2(F.silu(gate) * up) return out class Lfm2Attention(nn.Module): - """Attention with RoPE and Q/K layernorm - matches checkpoint weight names.""" + """Grouped-query attention with RoPE and Q/K layernorm.""" + def __init__( self, config: Lfm2Config, layer_id: int, - attn_layer_id: int, # Sequential ID for attention layers only (for KV cache) + attn_layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() - self.config = config - self.layer_id = layer_id - self.attn_layer_id = attn_layer_id - self.hidden_size = config.hidden_size self.total_num_heads = config.num_attention_heads self.total_num_kv_heads = config.num_key_value_heads - self.head_dim = getattr(config, "head_dim", None) or (self.hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", None) or ( + self.hidden_size // self.total_num_heads + ) self.scaling = self.head_dim**-0.5 rope_parameters = getattr(config, "rope_parameters", None) if rope_parameters is not None and "rope_theta" in rope_parameters: - self.rope_theta = rope_parameters["rope_theta"] + rope_theta = rope_parameters["rope_theta"] else: - self.rope_theta = getattr(config, "rope_theta", 10000) - - self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.rope_scaling = getattr(config, "rope_scaling", None) + rope_theta = getattr(config, "rope_theta", 10000) self.rotary_emb = get_rope( head_size=self.head_dim, rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - rope_scaling=self.rope_scaling, - base=self.rope_theta, + max_position=getattr(config, "max_position_embeddings", 8192), + rope_scaling=getattr(config, "rope_scaling", None), + base=rope_theta, is_neox_style=True, dtype=torch.get_default_dtype(), ) @@ -252,8 +151,6 @@ def __init__( quant_config=quant_config, prefix=add_prefix("qkv_proj", prefix), ) - - # Named out_proj to match checkpoint self.out_proj = RowParallelLinear( self.total_num_heads * 
self.head_dim, self.hidden_size, @@ -262,7 +159,6 @@ def __init__( prefix=add_prefix("out_proj", prefix), ) - # Named q_layernorm/k_layernorm to match checkpoint self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) @@ -274,7 +170,7 @@ def __init__( head_dim=self.head_dim, scaling=self.scaling, num_kv_heads=self.num_local_kv_heads, - layer_id=self.layer_id, # Use global layer ID for routing in hybrid backend + layer_id=layer_id, prefix=add_prefix("attn", prefix), ) @@ -289,35 +185,32 @@ def forward( q_size = self.num_local_q_heads * self.head_dim kv_size = self.num_local_kv_heads * self.head_dim - q, k, v = torch.split(qkv, [q_size, kv_size, kv_size], dim=-1) q = q.reshape(T, self.num_local_q_heads, self.head_dim) k = k.reshape(T, self.num_local_kv_heads, self.head_dim) - q = self.q_layernorm(q.reshape(-1, self.head_dim)).reshape(T, self.num_local_q_heads, self.head_dim) - k = self.k_layernorm(k.reshape(-1, self.head_dim)).reshape(T, self.num_local_kv_heads, self.head_dim) + q = self.q_layernorm(q.reshape(-1, self.head_dim)).reshape( + T, self.num_local_q_heads, self.head_dim + ) + k = self.k_layernorm(k.reshape(-1, self.head_dim)).reshape( + T, self.num_local_kv_heads, self.head_dim + ) q, k = self.rotary_emb(positions, q, k) - q = q.reshape(T, -1) - k = k.reshape(T, -1) - - attn_out = self.attn(q, k, v, forward_batch) - + attn_out = self.attn(q.reshape(T, -1), k.reshape(T, -1), v, forward_batch) out, _ = self.out_proj(attn_out) return out class Lfm2ShortConv(nn.Module): """ - Short conv implementation using SGLang's MambaPool for state management. - - This implementation: - 1. Uses nn.Linear for in_proj/out_proj (matching HF checkpoint) - 2. Accesses conv state through HybridReqToTokenPool.mamba2_layer_cache() - 3. Handles prefill and decode modes properly - 4. 
Is CUDA graph compatible (uses index_copy_ instead of .item()) + Gated short convolution layer using SGLang's MambaPool for state management. + + Architecture: in_proj -> split(B, C, x) -> Bx -> conv1d -> C*conv_out -> out_proj + - Uses double gating: B (before conv) and C (after conv) + - Fixed-size cache: stores last (kernel_size - 1) tokens """ def __init__( @@ -328,35 +221,21 @@ def __init__( prefix: str = "", ): super().__init__() - self.config = config self.layer_idx = layer_idx - # conv_L_cache in config is the kernel size (e.g., 3), NOT kernel-1 - # The "cache" stores kernel_size - 1 values for causal conv - self.conv_kernel = int(config.conv_L_cache) # kernel_size from config - self.L_cache = self.conv_kernel - 1 # actual cache size = kernel - 1 + self.conv_kernel = int(config.conv_L_cache) + self.L_cache = self.conv_kernel - 1 self.bias = bool(config.conv_bias) self.hidden_size = config.hidden_size - # Match HF exactly - use nn.Linear (not parallel versions) - self.in_proj = nn.Linear( - config.hidden_size, - 3 * config.hidden_size, - bias=self.bias, - ) - self.out_proj = nn.Linear( + self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) + self.conv = nn.Conv1d( config.hidden_size, config.hidden_size, - bias=self.bias, - ) - - # Depthwise conv1d with causal padding - self.conv = nn.Conv1d( - in_channels=config.hidden_size, - out_channels=config.hidden_size, kernel_size=self.conv_kernel, groups=config.hidden_size, bias=self.bias, - padding=self.L_cache, # Causal padding = kernel_size - 1 + padding=self.L_cache, ) def forward( @@ -364,51 +243,22 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: - """ - Forward pass using SGLang's hybrid caching infrastructure. 
- - Conv state is accessed through: - forward_batch.req_to_token_pool.mamba2_layer_cache(self.layer_idx) - """ - forward_mode = forward_batch.forward_mode - - if forward_mode.is_idle(): + if forward_batch.forward_mode.is_idle(): return hidden_states - - # Get conv cache through HybridReqToTokenPool - # mamba2_layer_cache returns a cache object with .conv and .temporal attributes + layer_cache = forward_batch.req_to_token_pool.mamba2_layer_cache(self.layer_idx) - # conv is a list of tensors, one per conv state component - # For Mamba2, conv[0] has shape [pool_size+1, conv_dim, conv_kernel-1] - # For LFM2, this is [pool_size+1, hidden_size, L_cache] conv_state = layer_cache.conv[0] - - # Get request pool indices for current batch req_pool_indices = forward_batch.req_pool_indices - - if forward_mode.is_decode(): - return self._forward_decode( - hidden_states, - conv_state, - req_pool_indices, + + if forward_batch.forward_mode.is_decode(): + return self._forward_decode(hidden_states, conv_state, req_pool_indices) + + seq_lens = getattr(forward_batch, "extend_seq_lens", None) + if seq_lens is not None and len(seq_lens) > 1: + return self._forward_prefill_multi( + hidden_states, conv_state, req_pool_indices, seq_lens ) - else: - # Prefill/extend mode - seq_lens = getattr(forward_batch, 'extend_seq_lens', None) - - if seq_lens is not None and len(seq_lens) > 1: - return self._forward_prefill_multi( - hidden_states, - conv_state, - req_pool_indices, - seq_lens, - ) - else: - return self._forward_prefill_single( - hidden_states, - conv_state, - req_pool_indices, - ) + return self._forward_prefill_single(hidden_states, conv_state, req_pool_indices) def _forward_prefill_single( self, @@ -416,44 +266,29 @@ def _forward_prefill_single( conv_state: torch.Tensor, req_pool_indices: torch.Tensor, ) -> torch.Tensor: - """Prefill for a single sequence - matches HF slow_forward exactly.""" T = hidden_states.shape[0] - - # Step 1: in_proj - proj = self.in_proj(hidden_states) # [T, 3H] 
- - # Step 2: transpose to [1, 3H, T] for conv - proj_t = proj.transpose(0, 1).unsqueeze(0) # [1, 3H, T] - - # Step 3: chunk into B, C, x - B_gate, C_gate, x = proj_t.chunk(3, dim=1) # each [1, H, T] - - # Step 4: Bx = B * x - Bx = B_gate * x # [1, H, T] - - # Step 5: conv with causal padding (output is truncated to T) - conv_out = self.conv(Bx)[..., :T] # [1, H, T] - - # Step 6: y = C * conv_out - y = C_gate * conv_out # [1, H, T] - - # Step 7: transpose back - y = y.squeeze(0).transpose(0, 1) # [T, H] - - # Step 8: out_proj - y = self.out_proj(y) # [T, H] - - # Store the final conv state (last L_cache values of Bx) + + proj = self.in_proj(hidden_states) + proj_t = proj.transpose(0, 1).unsqueeze(0) + B_gate, C_gate, x = proj_t.chunk(3, dim=1) + Bx = B_gate * x + conv_out = self.conv(Bx)[..., :T] + y = C_gate * conv_out + y = self.out_proj(y.squeeze(0).transpose(0, 1)) + + # Store final conv state if T >= self.L_cache: - final_state = Bx[0, :, -self.L_cache:] # [H, L_cache] + final_state = Bx[0, :, -self.L_cache :] else: - final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) # [H, L_cache] - - # Store for the request using index_copy_ (CUDA graph compatible) - # Ensure dtype matches conv_state (may be bfloat16 while final_state is float32) + final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) + if req_pool_indices.numel() > 0: - conv_state.index_copy_(0, req_pool_indices[:1].long(), final_state.unsqueeze(0).to(conv_state.dtype)) - + conv_state.index_copy_( + 0, + req_pool_indices[:1].long(), + final_state.unsqueeze(0).to(conv_state.dtype), + ) + return y def _forward_prefill_multi( @@ -463,44 +298,40 @@ def _forward_prefill_multi( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, ) -> torch.Tensor: - """Process multiple sequences separately to avoid cross-contamination.""" outputs = [] start_idx = 0 - - seq_lens_list = seq_lens.tolist() if isinstance(seq_lens, torch.Tensor) else list(seq_lens) + seq_lens_list = ( + seq_lens.tolist() if 
isinstance(seq_lens, torch.Tensor) else list(seq_lens) + ) req_pool_indices_long = req_pool_indices.long() - + for i, seq_len in enumerate(seq_lens_list): seq_len = int(seq_len) end_idx = start_idx + seq_len - seq_hidden = hidden_states[start_idx:end_idx] T = seq_hidden.shape[0] - - # Process this sequence + proj = self.in_proj(seq_hidden) proj_t = proj.transpose(0, 1).unsqueeze(0) B_gate, C_gate, x = proj_t.chunk(3, dim=1) Bx = B_gate * x conv_out = self.conv(Bx)[..., :T] y = C_gate * conv_out - y = y.squeeze(0).transpose(0, 1) - y = self.out_proj(y) - + y = self.out_proj(y.squeeze(0).transpose(0, 1)) outputs.append(y) - - # Store conv state for this sequence + if T >= self.L_cache: - final_state = Bx[0, :, -self.L_cache:] + final_state = Bx[0, :, -self.L_cache :] else: final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) - - # Use index_copy_ for CUDA graph compatibility - # Ensure dtype matches conv_state (may be bfloat16 while final_state is float32) - conv_state.index_copy_(0, req_pool_indices_long[i:i+1], final_state.unsqueeze(0).to(conv_state.dtype)) - + + conv_state.index_copy_( + 0, + req_pool_indices_long[i : i + 1], + final_state.unsqueeze(0).to(conv_state.dtype), + ) start_idx = end_idx - + return torch.cat(outputs, dim=0) def _forward_decode( @@ -509,77 +340,48 @@ def _forward_decode( conv_state: torch.Tensor, req_pool_indices: torch.Tensor, ) -> torch.Tensor: - """ - Decode mode: single token per sequence using cached state. - CUDA graph compatible - uses only tensor operations, no .item() calls. 
- """ - batch_size = hidden_states.shape[0] - req_pool_indices_long = req_pool_indices.long() - - # in_proj for all tokens - proj = self.in_proj(hidden_states) # [B, 3H] - - # Split into gates - B_gate, C_gate, x = proj.chunk(3, dim=-1) # each [B, H] - - # Compute Bx - Bx = B_gate * x # [B, H] - - # Get conv weights - shape is [H, 1, kernel_size], we need [H, kernel_size] - conv_weights = self.conv.weight[:, 0, :] # [H, kernel_size] - - # Gather current states: [B, H, L_cache] + + proj = self.in_proj(hidden_states) + B_gate, C_gate, x = proj.chunk(3, dim=-1) + Bx = B_gate * x + + conv_weights = self.conv.weight[:, 0, :] current_states = conv_state[req_pool_indices_long] - - # Roll states left by 1 and insert new Bx at the end - new_states = torch.cat([ - current_states[:, :, 1:], # [B, H, L_cache-1] - Bx.unsqueeze(-1) # [B, H, 1] - ], dim=-1) # [B, H, L_cache] - - # Scatter updated states back using index_copy_ (CUDA graph compatible) - # Ensure dtype matches conv_state (may be bfloat16 while new_states is float32) - conv_state.index_copy_(0, req_pool_indices_long, new_states.to(conv_state.dtype)) - - # Compute conv output: need full kernel_size inputs - # Prepend zeros to get [B, H, kernel_size] for the conv operation - # Actually, we need to apply the conv kernel to the last kernel_size values - # The state has L_cache = kernel_size - 1 values, plus the new Bx makes kernel_size - conv_input = torch.cat([ - current_states[:, :, -(self.conv_kernel - 1):], # [B, H, kernel_size-1] - Bx.unsqueeze(-1) # [B, H, 1] - ], dim=-1) # [B, H, kernel_size] - - # Apply conv weights: element-wise multiply and sum - conv_out = (conv_input * conv_weights.unsqueeze(0)).sum(dim=-1) # [B, H] - - # Add bias if present + + # Update state: roll left, insert new value at end + new_states = torch.cat( + [current_states[:, :, 1:], Bx.unsqueeze(-1)], dim=-1 + ) + conv_state.index_copy_( + 0, req_pool_indices_long, new_states.to(conv_state.dtype) + ) + + # Apply conv: use last kernel_size 
values + conv_input = torch.cat( + [current_states[:, :, -(self.conv_kernel - 1) :], Bx.unsqueeze(-1)], dim=-1 + ) + conv_out = (conv_input * conv_weights.unsqueeze(0)).sum(dim=-1) + if self.bias and self.conv.bias is not None: conv_out = conv_out + self.conv.bias - - # Apply output gate - y = C_gate * conv_out # [B, H] - - # Apply out_proj (ensure dtype matches model weights) - y = self.out_proj(y.to(hidden_states.dtype)) # [B, H] - return y + y = C_gate * conv_out + return self.out_proj(y.to(hidden_states.dtype)) class Lfm2DecoderLayer(nn.Module): - """Decoder layer - can be attention or conv based on config.""" + """Decoder layer - either attention or conv based on config.""" + def __init__( self, config: Lfm2Config, layer_id: int, - attn_layer_id: int, # Sequential ID for attention layers (for KV cache) + attn_layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() - self.config = config - self.layer_id = layer_id self.layer_type = config.layer_types[layer_id] self.is_attention_layer = self.layer_type == "full_attention" @@ -594,16 +396,13 @@ def __init__( quant_config=quant_config, prefix=add_prefix("self_attn", prefix), ) - elif self.layer_type in ("conv", "short_conv"): - # Named 'conv' to match checkpoint (model.layers.X.conv.*) + else: self.conv = Lfm2ShortConv( config=config, layer_idx=layer_id, quant_config=quant_config, prefix=add_prefix("conv", prefix), ) - else: - raise ValueError(f"Unknown layer type: {self.layer_type}") self.feed_forward = Lfm2MLP( config=config, @@ -620,26 +419,20 @@ def forward( forward_batch: ForwardBatch, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward matching HF exactly.""" if not forward_batch.forward_mode.is_idle(): residual = hidden_states normed = self.operator_norm(hidden_states) - + if self.is_attention_layer: - hidden_states = self.self_attn( - positions=positions, - hidden_states=normed, - forward_batch=forward_batch - ) + hidden_states = 
self.self_attn(positions, normed, forward_batch) else: - hidden_states = self.conv( - hidden_states=normed, - forward_batch=forward_batch, - ) - + hidden_states = self.conv(normed, forward_batch) + hidden_states = hidden_states + residual - hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states)) - + hidden_states = hidden_states + self.feed_forward( + self.ffn_norm(hidden_states) + ) + return hidden_states, residual @@ -660,7 +453,7 @@ def __init__( prefix=add_prefix("embed_tokens", prefix), ) - # Compute attention layer IDs (sequential numbering for KV cache) + # Compute attention layer IDs for KV cache attn_layer_ids = [] attn_count = 0 for layer_type in config.layer_types: @@ -669,12 +462,8 @@ def __init__( attn_count += 1 else: attn_layer_ids.append(-1) - + self.num_attention_layers = attn_count - - logger.info(f"LFM2 model has {attn_count} attention layers and " - f"{len(config.layer_types) - attn_count} conv layers " - f"out of {config.num_hidden_layers} total") def get_layer(idx: int, prefix: str, **kwargs): return Lfm2DecoderLayer( @@ -685,9 +474,9 @@ def get_layer(idx: int, prefix: str, **kwargs): prefix=prefix, ) - self.layers = make_layers(config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - - # Named embedding_norm to match checkpoint + self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) def forward( @@ -697,12 +486,13 @@ def forward( forward_batch: ForwardBatch, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids) + hidden_states = ( + inputs_embeds if inputs_embeds is not None else self.embed_tokens(input_ids) + ) residual = None for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( + hidden_states, residual = self.layers[i]( layer_id=i, 
positions=positions, hidden_states=hidden_states, @@ -710,26 +500,12 @@ def forward( forward_batch=forward_batch, ) - hidden_states = self.embedding_norm(hidden_states) - return hidden_states + return self.embedding_norm(hidden_states) class Lfm2ForCausalLM(nn.Module): - """ - LFM2 for Causal Language Modeling. - - This model has a hybrid architecture with both attention and conv layers. - - Attention layers use standard KV cache (managed by SGLang) - - Conv layers use MambaPool for state caching (via HybridReqToTokenPool) - - IMPORTANT: For this to work, you must also modify model_runner.py to add - Lfm2Config to the mamba2_config property's isinstance check: - - from transformers import Lfm2Config - ... - if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config): - return config - """ + """LFM2 for causal language modeling with hybrid attention/conv architecture.""" + fall_back_to_pt_during_load = False def __init__( @@ -753,12 +529,9 @@ def __init__( prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) - - # Store number of attention layers for KV cache sizing self.num_attention_layers = self.model.num_attention_layers def get_num_kv_cache_layers(self) -> int: - """Return the number of layers that need KV cache (attention layers only).""" return self.num_attention_layers @torch.no_grad() @@ -770,15 +543,14 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): - hidden_states = self.model( - input_ids, - positions, - forward_batch, - inputs_embeds, + hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch ) - return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False) -> Set[str]: + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False + 
) -> Set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -786,89 +558,49 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool ] params_dict = dict(self.named_parameters()) - - logger.info("=== Model parameter names (first 30) ===") - for i, (name, param) in enumerate(params_dict.items()): - if i >= 30: - break - logger.info(f" {name}: {param.shape}") - loaded_params: Set[str] = set() - conv_weights_loaded = 0 - missing_params = [] - embed_tokens_weight = None for name, loaded_weight in weights: - original_name = name - if "rotary_emb.inv_freq" in name: continue - + if "embed_tokens.weight" in name: embed_tokens_weight = loaded_weight - if conv_weights_loaded < 5 and ".conv." in name: - logger.info(f"Loading: {name}, shape: {loaded_weight.shape}") - # Handle QKV stacking - did_stack = False for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) if name.endswith(".bias") and name not in params_dict: - did_stack = True break if name not in params_dict: - did_stack = True break param = params_dict[name] weight_loader = getattr(param, "weight_loader") weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) - did_stack = True break - if did_stack: - continue - - if name.endswith(".bias") and name not in params_dict: - continue - - if name not in params_dict: - if len(missing_params) < 20: - missing_params.append(f"{original_name} -> {name}") - continue + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - - if ".conv." 
in name: - conv_weights_loaded += 1 + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) # Handle tied lm_head weight if "lm_head.weight" not in loaded_params and "lm_head.weight" in params_dict: if embed_tokens_weight is not None: - logger.info("Tying lm_head.weight to embed_tokens.weight") param = params_dict["lm_head.weight"] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, embed_tokens_weight) loaded_params.add("lm_head.weight") - else: - logger.warning("lm_head.weight not found and no embed_tokens.weight to tie") - - if missing_params: - logger.warning(f"Missing params (first 20): {missing_params}") - - logger.info(f"Loaded {conv_weights_loaded} conv weight tensors") - logger.info(f"Total loaded params: {len(loaded_params)}") - - unloaded = set(params_dict.keys()) - loaded_params - if unloaded: - logger.warning(f"Unloaded params ({len(unloaded)}): {list(unloaded)[:10]}...") - + return loaded_params From dafdb3b70476d6168384c768d7790fa26c43f9dc Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Wed, 7 Jan 2026 16:49:47 +0000 Subject: [PATCH 05/14] LMF2 - optimized conv1d forward pass --- python/sglang/srt/models/lfm2.py | 190 +++++++++++-------------------- 1 file changed, 67 insertions(+), 123 deletions(-) diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py index f30e9b77a1aa..e8d9fc18cacd 100644 --- a/python/sglang/srt/models/lfm2.py +++ b/python/sglang/srt/models/lfm2.py @@ -7,6 +7,8 @@ The model uses a gated 1D causal convolution (kernel=3) instead of attention in some layers, providing linear memory complexity for those layers. + +Uses optimized causal_conv1d kernels from the mamba package for fast inference. 
""" import logging @@ -17,6 +19,10 @@ from torch import nn from sglang.srt.configs.lfm2 import Lfm2Config +from sglang.srt.layers.attention.mamba.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) from sglang.srt.distributed import get_pp_group from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -206,11 +212,12 @@ def forward( class Lfm2ShortConv(nn.Module): """ - Gated short convolution layer using SGLang's MambaPool for state management. + Gated short convolution layer using optimized causal_conv1d kernels. Architecture: in_proj -> split(B, C, x) -> Bx -> conv1d -> C*conv_out -> out_proj - Uses double gating: B (before conv) and C (after conv) - Fixed-size cache: stores last (kernel_size - 1) tokens + - Uses causal_conv1d_fn for prefill and causal_conv1d_update for decode """ def __init__( @@ -224,19 +231,19 @@ def __init__( self.layer_idx = layer_idx self.conv_kernel = int(config.conv_L_cache) self.L_cache = self.conv_kernel - 1 - self.bias = bool(config.conv_bias) + self.use_bias = bool(config.conv_bias) self.hidden_size = config.hidden_size - self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) - self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) - self.conv = nn.Conv1d( - config.hidden_size, - config.hidden_size, - kernel_size=self.conv_kernel, - groups=config.hidden_size, - bias=self.bias, - padding=self.L_cache, - ) + self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.use_bias) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.use_bias) + + # Conv weights stored in format matching causal_conv1d: (hidden_size, kernel_size) + # Weight loading will handle conversion from HF's (hidden_size, 1, kernel_size) + self.conv_weight = nn.Parameter(torch.empty(config.hidden_size, self.conv_kernel)) + if self.use_bias: + self.conv_bias = nn.Parameter(torch.empty(config.hidden_size)) + else: + 
self.register_parameter("conv_bias", None) def forward( self, @@ -250,124 +257,50 @@ def forward( conv_state = layer_cache.conv[0] req_pool_indices = forward_batch.req_pool_indices - if forward_batch.forward_mode.is_decode(): - return self._forward_decode(hidden_states, conv_state, req_pool_indices) - - seq_lens = getattr(forward_batch, "extend_seq_lens", None) - if seq_lens is not None and len(seq_lens) > 1: - return self._forward_prefill_multi( - hidden_states, conv_state, req_pool_indices, seq_lens - ) - return self._forward_prefill_single(hidden_states, conv_state, req_pool_indices) - - def _forward_prefill_single( - self, - hidden_states: torch.Tensor, - conv_state: torch.Tensor, - req_pool_indices: torch.Tensor, - ) -> torch.Tensor: - T = hidden_states.shape[0] - + # Project and split into gates: B (pre-conv), C (post-conv), x (input) proj = self.in_proj(hidden_states) - proj_t = proj.transpose(0, 1).unsqueeze(0) - B_gate, C_gate, x = proj_t.chunk(3, dim=1) + B_gate, C_gate, x = proj.chunk(3, dim=-1) Bx = B_gate * x - conv_out = self.conv(Bx)[..., :T] - y = C_gate * conv_out - y = self.out_proj(y.squeeze(0).transpose(0, 1)) - # Store final conv state - if T >= self.L_cache: - final_state = Bx[0, :, -self.L_cache :] - else: - final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) - - if req_pool_indices.numel() > 0: - conv_state.index_copy_( - 0, - req_pool_indices[:1].long(), - final_state.unsqueeze(0).to(conv_state.dtype), + if forward_batch.forward_mode.is_decode(): + # Decode: single token per request, use optimized update kernel + conv_out = causal_conv1d_update( + Bx, + conv_state, + self.conv_weight, + self.conv_bias, + activation=None, + conv_state_indices=req_pool_indices.to(torch.int32), ) - - return y - - def _forward_prefill_multi( - self, - hidden_states: torch.Tensor, - conv_state: torch.Tensor, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - ) -> torch.Tensor: - outputs = [] - start_idx = 0 - seq_lens_list = ( - 
seq_lens.tolist() if isinstance(seq_lens, torch.Tensor) else list(seq_lens) - ) - req_pool_indices_long = req_pool_indices.long() - - for i, seq_len in enumerate(seq_lens_list): - seq_len = int(seq_len) - end_idx = start_idx + seq_len - seq_hidden = hidden_states[start_idx:end_idx] - T = seq_hidden.shape[0] - - proj = self.in_proj(seq_hidden) - proj_t = proj.transpose(0, 1).unsqueeze(0) - B_gate, C_gate, x = proj_t.chunk(3, dim=1) - Bx = B_gate * x - conv_out = self.conv(Bx)[..., :T] - y = C_gate * conv_out - y = self.out_proj(y.squeeze(0).transpose(0, 1)) - outputs.append(y) - - if T >= self.L_cache: - final_state = Bx[0, :, -self.L_cache :] + else: + # Prefill: multiple tokens, use varlen kernel + T = hidden_states.shape[0] + Bx_t = Bx.transpose(0, 1).contiguous() + + # Build query_start_loc: [0, cumsum(seq_lens)...] + extend_start_loc = forward_batch.extend_start_loc + if extend_start_loc is not None and len(extend_start_loc) > 1: + query_start_loc = torch.cat([ + extend_start_loc, + torch.tensor([T], dtype=torch.int32, device=hidden_states.device) + ]) + cache_indices = req_pool_indices.to(torch.int32) else: - final_state = F.pad(Bx[0], (self.L_cache - T, 0), value=0.0) - - conv_state.index_copy_( - 0, - req_pool_indices_long[i : i + 1], - final_state.unsqueeze(0).to(conv_state.dtype), - ) - start_idx = end_idx + query_start_loc = torch.tensor([0, T], dtype=torch.int32, device=hidden_states.device) + cache_indices = req_pool_indices[:1].to(torch.int32) - return torch.cat(outputs, dim=0) + conv_out = causal_conv1d_fn( + Bx_t, + self.conv_weight, + self.conv_bias, + query_start_loc=query_start_loc, + cache_indices=cache_indices, + has_initial_state=None, + conv_states=conv_state, + activation=None, + ).transpose(0, 1) - def _forward_decode( - self, - hidden_states: torch.Tensor, - conv_state: torch.Tensor, - req_pool_indices: torch.Tensor, - ) -> torch.Tensor: - req_pool_indices_long = req_pool_indices.long() - - proj = self.in_proj(hidden_states) - B_gate, 
C_gate, x = proj.chunk(3, dim=-1) - Bx = B_gate * x - - conv_weights = self.conv.weight[:, 0, :] - current_states = conv_state[req_pool_indices_long] - - # Update state: roll left, insert new value at end - new_states = torch.cat( - [current_states[:, :, 1:], Bx.unsqueeze(-1)], dim=-1 - ) - conv_state.index_copy_( - 0, req_pool_indices_long, new_states.to(conv_state.dtype) - ) - - # Apply conv: use last kernel_size values - conv_input = torch.cat( - [current_states[:, :, -(self.conv_kernel - 1) :], Bx.unsqueeze(-1)], dim=-1 - ) - conv_out = (conv_input * conv_weights.unsqueeze(0)).sum(dim=-1) - - if self.bias and self.conv.bias is not None: - conv_out = conv_out + self.conv.bias - - y = C_gate * conv_out - return self.out_proj(y.to(hidden_states.dtype)) + return self.out_proj(C_gate * conv_out) class Lfm2DecoderLayer(nn.Module): @@ -568,6 +501,17 @@ def load_weights( if "embed_tokens.weight" in name: embed_tokens_weight = loaded_weight + # Handle conv.weight -> conv_weight conversion for ShortConv layers + # HF shape: (hidden_size, 1, kernel_size) -> squeeze to (hidden_size, kernel_size) + if ".conv.weight" in name: + name = name.replace(".conv.weight", ".conv_weight") + # Squeeze out the middle dimension: (D, 1, K) -> (D, K) + loaded_weight = loaded_weight.squeeze(1) + + # Handle conv.bias -> conv_bias conversion + if ".conv.bias" in name: + name = name.replace(".conv.bias", ".conv_bias") + # Handle QKV stacking for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: From 6b3ff94af0c6a0ef3fb64530d30e059b20bbdbfc Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Thu, 8 Jan 2026 10:29:26 +0000 Subject: [PATCH 06/14] Use optimized RMSNorm kernel --- python/sglang/srt/models/lfm2.py | 37 +++++++++++++++++--------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py index e8d9fc18cacd..2b2994c55195 100644 --- a/python/sglang/srt/models/lfm2.py +++ 
b/python/sglang/srt/models/lfm2.py @@ -24,6 +24,7 @@ causal_conv1d_update, ) from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -44,20 +45,22 @@ logger = logging.getLogger(__name__) -class Lfm2RMSNorm(nn.Module): - """LFM2-specific RMSNorm: weight * x (not (1 + weight) * x like Gemma).""" +# We don't use it, we keep it for reference. If we run sglang.srt.layers.layernorm.RMSNorm +# kernel for some reason the difference in logporbs slighlty increases, but to an acceptable degree +# class Lfm2RMSNorm(nn.Module): +# """LFM2-specific RMSNorm: weight * x (not (1 + weight) * x like Gemma).""" - def __init__(self, hidden_size: int, eps: float = 1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps +# def __init__(self, hidden_size: int, eps: float = 1e-6): +# super().__init__() +# self.weight = nn.Parameter(torch.ones(hidden_size)) +# self.variance_epsilon = eps - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return (self.weight * hidden_states).to(input_dtype) +# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: +# input_dtype = hidden_states.dtype +# hidden_states = hidden_states.to(torch.float32) +# variance = hidden_states.pow(2).mean(-1, keepdim=True) +# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) +# return (self.weight * hidden_states).to(input_dtype) class Lfm2MLP(nn.Module): @@ -165,8 +168,8 @@ def __init__( prefix=add_prefix("out_proj", prefix), ) - self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) - self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + 
self.q_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) + self.k_layernorm = RMSNorm(self.head_dim, eps=config.norm_eps) self.num_local_q_heads = self.qkv_proj.num_heads self.num_local_kv_heads = self.qkv_proj.num_kv_heads @@ -318,8 +321,8 @@ def __init__( self.layer_type = config.layer_types[layer_id] self.is_attention_layer = self.layer_type == "full_attention" - self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) - self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + self.operator_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) if self.is_attention_layer: self.self_attn = Lfm2Attention( @@ -410,7 +413,7 @@ def get_layer(idx: int, prefix: str, **kwargs): self.layers = make_layers( config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" ) - self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) def forward( self, From a168c427d6c0954f4f644fdc9f769f970ea659a8 Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Thu, 8 Jan 2026 10:31:45 +0000 Subject: [PATCH 07/14] Handle the cuda graph caching issue steaming from the default dtype --- python/sglang/srt/configs/lfm2.py | 14 ++++++++++++-- python/sglang/srt/configs/mamba_utils.py | 15 +++++++++++---- python/sglang/srt/model_executor/model_runner.py | 4 +++- python/sglang/srt/models/lfm2.py | 2 +- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/configs/lfm2.py b/python/sglang/srt/configs/lfm2.py index b94cd5f11afa..e4c29048557e 100644 --- a/python/sglang/srt/configs/lfm2.py +++ b/python/sglang/srt/configs/lfm2.py @@ -20,7 +20,8 @@ from transformers import CONFIG_MAPPING from transformers.utils import logging -from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape +import torch +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, 
Mamba2StateShape, Mamba2StateDType logger = logging.get_logger(__name__) @@ -85,7 +86,16 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: conv_kernel=conv_kernel, ) - return Mamba2CacheParams(shape=shape, layers=conv_layer_ids) + # Get conv dtype from torch default (set by model runner before this is called) + # Fall back to bfloat16 if default is float32 + default_dtype = torch.get_default_dtype() + conv_dtype = default_dtype if default_dtype in (torch.float16, torch.bfloat16) else torch.bfloat16 + + return Mamba2CacheParams( + shape=shape, + layers=conv_layer_ids, + dtype=Mamba2StateDType(conv=conv_dtype, temporal=torch.float32), + ) # Override HuggingFace's Lfm2Config with our extended version diff --git a/python/sglang/srt/configs/mamba_utils.py b/python/sglang/srt/configs/mamba_utils.py index d2ff3762b140..6710715117ef 100644 --- a/python/sglang/srt/configs/mamba_utils.py +++ b/python/sglang/srt/configs/mamba_utils.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc.""" +"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, LFM2, etc.""" import os from abc import ABC @@ -41,16 +41,23 @@ class Mamba2StateDType: temporal: torch.dtype -CONV_DTYPE = torch.bfloat16 +def get_conv_dtype() -> torch.dtype: + """Get conv state dtype - uses torch default dtype which is set by model runner.""" + default_dtype = torch.get_default_dtype() + # Only use float16/bfloat16 for conv state, fall back to bfloat16 for float32 + if default_dtype in (torch.float16, torch.bfloat16): + return default_dtype + return torch.bfloat16 def mamba2_state_dtype() -> Mamba2StateDType: dtype_map = { "float32": torch.float32, "bfloat16": torch.bfloat16, + "float16": torch.float16, } - ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] - return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype) + ssm_dtype = dtype_map.get(os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32"), torch.float32) + return Mamba2StateDType(conv=get_conv_dtype(), temporal=ssm_dtype) @dataclass(kw_only=True, frozen=True) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index dbb732544b17..5325f309f7e8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -545,7 +545,9 @@ def initialize(self, min_per_gpu_memory: float): self.configure_kv_cache_dtype() # Init memory pool and attention backends - self.init_memory_pool(min_per_gpu_memory) + # Set default dtype so mamba2_cache_params picks up the correct dtype for conv state + with set_default_torch_dtype(self.model_config.dtype): + self.init_memory_pool(min_per_gpu_memory) # Init max running requests self.max_running_requests = min( diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py index 2b2994c55195..e99c3e001e28 100644 --- a/python/sglang/srt/models/lfm2.py +++ 
b/python/sglang/srt/models/lfm2.py @@ -46,7 +46,7 @@ # We don't use it, we keep it for reference. If we run sglang.srt.layers.layernorm.RMSNorm -# kernel for some reason the difference in logporbs slighlty increases, but to an acceptable degree +# kernel the difference in logporbs slighlty increases, but to an acceptable degree # class Lfm2RMSNorm(nn.Module): # """LFM2-specific RMSNorm: weight * x (not (1 + weight) * x like Gemma).""" From 27a59cd2df2b4c95b04f2180dd219bc38c90e5a1 Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Thu, 8 Jan 2026 12:12:03 +0000 Subject: [PATCH 08/14] Add LFM2 to tool_choice tests --- .../function_call/test_tool_choice.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/registered/openai_server/function_call/test_tool_choice.py b/test/registered/openai_server/function_call/test_tool_choice.py index b12cd70d0ced..6d05c0066a0f 100644 --- a/test/registered/openai_server/function_call/test_tool_choice.py +++ b/test/registered/openai_server/function_call/test_tool_choice.py @@ -855,5 +855,33 @@ def test_complex_parameters_required_non_streaming(self): # cls.tokenizer = get_tokenizer(cls.model) +class TestToolChoiceLfm2(TestToolChoiceLlama32): + """Test tool_choice functionality with LiquidAI LFM2 model""" + + @classmethod + def setUpClass(cls): + cls.flaky_tests = { + "test_multi_tool_scenario_auto", + "test_multi_tool_scenario_required", + } + + cls.model = "LiquidAI/LFM2.5-1.2B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--tool-call-parser", + "lfm2", + ], + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(cls.model) + + if __name__ == "__main__": unittest.main() From 493eb5b62635ca91822b180cee1fe69d3685ed8e Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Sat, 10 Jan 2026 12:11:55 +0000 Subject: [PATCH 09/14] Apply 
pre-commit formatting fixes (isort, black, typo fix) --- python/sglang/srt/configs/__init__.py | 2 +- python/sglang/srt/configs/lfm2.py | 20 +++++++--- python/sglang/srt/configs/mamba_utils.py | 4 +- .../sglang/srt/function_call/lfm2_detector.py | 39 +++++++++++++------ .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/lfm2.py | 34 ++++++++++------ .../test_function_call_parser.py | 6 ++- 7 files changed, 75 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index eadb3f03521a..120da37ec5bd 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -6,12 +6,12 @@ from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.falcon_h1 import FalconH1Config from sglang.srt.configs.janus_pro import MultiModalityConfig -from sglang.srt.configs.lfm2 import Lfm2Config from sglang.srt.configs.jet_nemotron import JetNemotronConfig from sglang.srt.configs.jet_vlm import JetVLMConfig from sglang.srt.configs.kimi_linear import KimiLinearConfig from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig +from sglang.srt.configs.lfm2 import Lfm2Config from sglang.srt.configs.longcat_flash import LongcatFlashConfig from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config from sglang.srt.configs.nemotron_h import NemotronHConfig diff --git a/python/sglang/srt/configs/lfm2.py b/python/sglang/srt/configs/lfm2.py index e4c29048557e..b189c8b69f85 100644 --- a/python/sglang/srt/configs/lfm2.py +++ b/python/sglang/srt/configs/lfm2.py @@ -16,12 +16,16 @@ from typing import List, Optional -from transformers import Lfm2Config as HFLfm2Config +import torch from transformers import CONFIG_MAPPING +from transformers import Lfm2Config as HFLfm2Config from transformers.utils import logging -import torch -from sglang.srt.configs.mamba_utils import Mamba2CacheParams, 
Mamba2StateShape, Mamba2StateDType +from sglang.srt.configs.mamba_utils import ( + Mamba2CacheParams, + Mamba2StateDType, + Mamba2StateShape, +) logger = logging.get_logger(__name__) @@ -42,7 +46,9 @@ def full_attention_layer_ids(self) -> List[int]: @property def linear_layer_ids(self) -> List[int]: """Return indices of conv layers for conv state cache.""" - return [i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")] + return [ + i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv") + ] @property def mamba_chunk_size(self) -> int: @@ -89,7 +95,11 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: # Get conv dtype from torch default (set by model runner before this is called) # Fall back to bfloat16 if default is float32 default_dtype = torch.get_default_dtype() - conv_dtype = default_dtype if default_dtype in (torch.float16, torch.bfloat16) else torch.bfloat16 + conv_dtype = ( + default_dtype + if default_dtype in (torch.float16, torch.bfloat16) + else torch.bfloat16 + ) return Mamba2CacheParams( shape=shape, diff --git a/python/sglang/srt/configs/mamba_utils.py b/python/sglang/srt/configs/mamba_utils.py index 6710715117ef..1135a0ba53ab 100644 --- a/python/sglang/srt/configs/mamba_utils.py +++ b/python/sglang/srt/configs/mamba_utils.py @@ -56,7 +56,9 @@ def mamba2_state_dtype() -> Mamba2StateDType: "bfloat16": torch.bfloat16, "float16": torch.float16, } - ssm_dtype = dtype_map.get(os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32"), torch.float32) + ssm_dtype = dtype_map.get( + os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32"), torch.float32 + ) return Mamba2StateDType(conv=get_conv_dtype(), temporal=ssm_dtype) diff --git a/python/sglang/srt/function_call/lfm2_detector.py b/python/sglang/srt/function_call/lfm2_detector.py index fde95c4e41f0..f4302fdd20ae 100644 --- a/python/sglang/srt/function_call/lfm2_detector.py +++ b/python/sglang/srt/function_call/lfm2_detector.py @@ -102,9 +102,13 @@ def 
_get_parameter_value(self, val: ast.AST) -> Any: return -inner raise ValueError(f"Cannot negate non-numeric value: {inner}") else: - raise ValueError(f"Tool call arguments must be literals, got: {type(val).__name__}") + raise ValueError( + f"Tool call arguments must be literals, got: {type(val).__name__}" + ) - def _parse_pythonic_call(self, call: ast.Call, call_index: int, tool_indices: Dict[str, int]) -> Optional[ToolCallItem]: + def _parse_pythonic_call( + self, call: ast.Call, call_index: int, tool_indices: Dict[str, int] + ) -> Optional[ToolCallItem]: """ Parse a single AST Call node into a ToolCallItem. @@ -117,14 +121,18 @@ def _parse_pythonic_call(self, call: ast.Call, call_index: int, tool_indices: Di ToolCallItem if successful, None if the call should be skipped """ if not isinstance(call.func, ast.Name): - logger.warning(f"Tool call function must be a simple name, got: {type(call.func).__name__}") + logger.warning( + f"Tool call function must be a simple name, got: {type(call.func).__name__}" + ) return None function_name = call.func.id # Validate that the function exists in the tools if function_name not in tool_indices: - logger.warning(f"Model attempted to call undefined function: {function_name}") + logger.warning( + f"Model attempted to call undefined function: {function_name}" + ) if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get(): return None # Skip unknown tools (default legacy behavior) @@ -147,7 +155,9 @@ def _parse_pythonic_call(self, call: ast.Call, call_index: int, tool_indices: Di parameters=json.dumps(arguments, ensure_ascii=False), ) - def _parse_pythonic_content(self, content: str, tools: List[Tool]) -> Tuple[List[ToolCallItem], str]: + def _parse_pythonic_content( + self, content: str, tools: List[Tool] + ) -> Tuple[List[ToolCallItem], str]: """ Parse Pythonic format tool calls using AST. 
@@ -174,7 +184,10 @@ def _parse_pythonic_content(self, content: str, tools: List[Tool]) -> Tuple[List elif isinstance(parsed, ast.Call): call_nodes = [parsed] else: - return [], f"Expected function call or list, got: {type(parsed).__name__}" + return ( + [], + f"Expected function call or list, got: {type(parsed).__name__}", + ) # Validate all elements are calls if not all(isinstance(e, ast.Call) for e in call_nodes): @@ -194,7 +207,9 @@ def _parse_pythonic_content(self, content: str, tools: List[Tool]) -> Tuple[List logger.exception("Unexpected error in pythonic tool call parsing") return [], f"Unexpected error: {e}" - def _parse_json_content(self, content: str, tools: List[Tool]) -> Tuple[List[ToolCallItem], str]: + def _parse_json_content( + self, content: str, tools: List[Tool] + ) -> Tuple[List[ToolCallItem], str]: """ Parse JSON format tool calls. @@ -220,7 +235,9 @@ def _parse_json_content(self, content: str, tools: List[Tool]) -> Tuple[List[Too except json.JSONDecodeError as e: return [], f"JSON parse error: {e}" - def _parse_tool_calls_content(self, content: str, tools: List[Tool]) -> List[ToolCallItem]: + def _parse_tool_calls_content( + self, content: str, tools: List[Tool] + ) -> List[ToolCallItem]: """ Parse the content between tool call tags. Handles both JSON and Pythonic formats. @@ -376,8 +393,8 @@ def parse_streaming_increment( return StreamingParseResult(normal_text="") # We have a complete tool call block - tool_call_block = self._buffer[bot_pos:eot_pos + len(self.eot_token)] - remaining = self._buffer[eot_pos + len(self.eot_token):] + tool_call_block = self._buffer[bot_pos : eot_pos + len(self.eot_token)] + remaining = self._buffer[eot_pos + len(self.eot_token) :] # Parse the complete block result = self.detect_and_parse(tool_call_block, tools) @@ -408,7 +425,7 @@ def structure_info(self) -> _GetInfoFunc: supports_structural_tag() returns False. 
""" return lambda name: StructureInfo( - begin='<|tool_call_start|>[' + name + '(', + begin="<|tool_call_start|>[" + name + "(", end=")]<|tool_call_end|>", trigger="<|tool_call_start|>", ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5325f309f7e8..9fa40b1906aa 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -35,10 +35,10 @@ JetNemotronConfig, JetVLMConfig, KimiLinearConfig, + Lfm2Config, NemotronH_Nano_VL_V2_Config, NemotronHConfig, Qwen3NextConfig, - Lfm2Config, ) from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat diff --git a/python/sglang/srt/models/lfm2.py b/python/sglang/srt/models/lfm2.py index e99c3e001e28..639acb3819b5 100644 --- a/python/sglang/srt/models/lfm2.py +++ b/python/sglang/srt/models/lfm2.py @@ -19,11 +19,11 @@ from torch import nn from sglang.srt.configs.lfm2 import Lfm2Config +from sglang.srt.distributed import get_pp_group from sglang.srt.layers.attention.mamba.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update, ) -from sglang.srt.distributed import get_pp_group from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -45,8 +45,8 @@ logger = logging.getLogger(__name__) -# We don't use it, we keep it for reference. If we run sglang.srt.layers.layernorm.RMSNorm -# kernel the difference in logporbs slighlty increases, but to an acceptable degree +# We don't use it, we keep it for reference. 
If we run sglang.srt.layers.layernorm.RMSNorm +# kernel the difference in logprobs slightly increases, but to an acceptable degree # class Lfm2RMSNorm(nn.Module): # """LFM2-specific RMSNorm: weight * x (not (1 + weight) * x like Gemma).""" @@ -237,12 +237,18 @@ def __init__( self.use_bias = bool(config.conv_bias) self.hidden_size = config.hidden_size - self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.use_bias) - self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.use_bias) + self.in_proj = nn.Linear( + config.hidden_size, 3 * config.hidden_size, bias=self.use_bias + ) + self.out_proj = nn.Linear( + config.hidden_size, config.hidden_size, bias=self.use_bias + ) # Conv weights stored in format matching causal_conv1d: (hidden_size, kernel_size) # Weight loading will handle conversion from HF's (hidden_size, 1, kernel_size) - self.conv_weight = nn.Parameter(torch.empty(config.hidden_size, self.conv_kernel)) + self.conv_weight = nn.Parameter( + torch.empty(config.hidden_size, self.conv_kernel) + ) if self.use_bias: self.conv_bias = nn.Parameter(torch.empty(config.hidden_size)) else: @@ -283,13 +289,19 @@ def forward( # Build query_start_loc: [0, cumsum(seq_lens)...] 
extend_start_loc = forward_batch.extend_start_loc if extend_start_loc is not None and len(extend_start_loc) > 1: - query_start_loc = torch.cat([ - extend_start_loc, - torch.tensor([T], dtype=torch.int32, device=hidden_states.device) - ]) + query_start_loc = torch.cat( + [ + extend_start_loc, + torch.tensor( + [T], dtype=torch.int32, device=hidden_states.device + ), + ] + ) cache_indices = req_pool_indices.to(torch.int32) else: - query_start_loc = torch.tensor([0, T], dtype=torch.int32, device=hidden_states.device) + query_start_loc = torch.tensor( + [0, T], dtype=torch.int32, device=hidden_states.device + ) cache_indices = req_pool_indices[:1].to(torch.int32) conv_out = causal_conv1d_fn( diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index b7e1fc99f23b..93906dad1f28 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -5,12 +5,12 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector -from sglang.srt.function_call.lfm2_detector import Lfm2Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector from sglang.srt.function_call.json_array_parser import JsonArrayParser from sglang.srt.function_call.kimik2_detector import KimiK2Detector +from sglang.srt.function_call.lfm2_detector import Lfm2Detector from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector @@ -3151,7 +3151,9 @@ def 
test_detect_and_parse_with_normal_text_before(self): def test_detect_and_parse_special_characters_in_value(self): """Test parsing with special characters in argument values.""" - text = '<|tool_call_start|>[search(query="what\'s the weather?")]<|tool_call_end|>' + text = ( + '<|tool_call_start|>[search(query="what\'s the weather?")]<|tool_call_end|>' + ) result = self.detector.detect_and_parse(text, self.tools) self.assertEqual(len(result.calls), 1) From 123e04ea599b0c012682aa40462a1ea3f365b496 Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Sun, 11 Jan 2026 08:24:09 +0000 Subject: [PATCH 10/14] Removed unused _find_matching_bracket --- .../sglang/srt/function_call/lfm2_detector.py | 44 ------------------- 1 file changed, 44 deletions(-) diff --git a/python/sglang/srt/function_call/lfm2_detector.py b/python/sglang/srt/function_call/lfm2_detector.py index f4302fdd20ae..80ef9c452333 100644 --- a/python/sglang/srt/function_call/lfm2_detector.py +++ b/python/sglang/srt/function_call/lfm2_detector.py @@ -288,50 +288,6 @@ def _strip_special_tokens(self, text: str) -> str: """Remove special tokens from text.""" return text.replace(self.bot_token, "").replace(self.eot_token, "") - def _find_matching_bracket(self, buffer: str, start: int) -> int: - """ - Find the matching closing bracket for the opening bracket at start position. - Properly handles nested brackets and strings. 
- - Args: - buffer: The text buffer to search in - start: Position of the opening bracket '[' - - Returns: - Position of the matching closing bracket ']', or -1 if not found - """ - bracket_count = 0 - in_string = False - string_char = None - escape_next = False - - for i in range(start, len(buffer)): - char = buffer[i] - - if escape_next: - escape_next = False - continue - - if char == "\\": - escape_next = True - continue - - if char in ('"', "'") and not in_string: - in_string = True - string_char = char - elif char == string_char and in_string: - in_string = False - string_char = None - elif not in_string: - if char == "[": - bracket_count += 1 - elif char == "]": - bracket_count -= 1 - if bracket_count == 0: - return i - - return -1 # No matching bracket found - def parse_streaming_increment( self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: From 97e613826d1f4bdff19b21b4134900f98d7a5f8c Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Thu, 15 Jan 2026 11:19:02 +0000 Subject: [PATCH 11/14] Use model dtype for LFM2 conv state cache --- python/sglang/srt/configs/lfm2.py | 11 +++-------- python/sglang/srt/configs/mamba_utils.py | 10 ++-------- python/sglang/srt/configs/model_config.py | 4 ++++ python/sglang/srt/model_executor/model_runner.py | 10 +++++++--- 4 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/configs/lfm2.py b/python/sglang/srt/configs/lfm2.py index b189c8b69f85..6ce5ad97940d 100644 --- a/python/sglang/srt/configs/lfm2.py +++ b/python/sglang/srt/configs/lfm2.py @@ -92,14 +92,9 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: conv_kernel=conv_kernel, ) - # Get conv dtype from torch default (set by model runner before this is called) - # Fall back to bfloat16 if default is float32 - default_dtype = torch.get_default_dtype() - conv_dtype = ( - default_dtype - if default_dtype in (torch.float16, torch.bfloat16) - else torch.bfloat16 - ) + # Use runtime dtype from model_config (propagated 
to hf_config.torch_dtype) + # Conv state must match model's inference dtype + conv_dtype = getattr(self, "torch_dtype", torch.bfloat16) return Mamba2CacheParams( shape=shape, diff --git a/python/sglang/srt/configs/mamba_utils.py b/python/sglang/srt/configs/mamba_utils.py index 1135a0ba53ab..810709fc9596 100644 --- a/python/sglang/srt/configs/mamba_utils.py +++ b/python/sglang/srt/configs/mamba_utils.py @@ -41,13 +41,7 @@ class Mamba2StateDType: temporal: torch.dtype -def get_conv_dtype() -> torch.dtype: - """Get conv state dtype - uses torch default dtype which is set by model runner.""" - default_dtype = torch.get_default_dtype() - # Only use float16/bfloat16 for conv state, fall back to bfloat16 for float32 - if default_dtype in (torch.float16, torch.bfloat16): - return default_dtype - return torch.bfloat16 +CONV_DTYPE = torch.bfloat16 def mamba2_state_dtype() -> Mamba2StateDType: @@ -59,7 +53,7 @@ def mamba2_state_dtype() -> Mamba2StateDType: ssm_dtype = dtype_map.get( os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32"), torch.float32 ) - return Mamba2StateDType(conv=get_conv_dtype(), temporal=ssm_dtype) + return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype) @dataclass(kw_only=True, frozen=True) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index aa10cb08d653..e363fc5da2d4 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -191,6 +191,10 @@ def __init__( ) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + # Propagate runtime dtype to hf_config so that hybrid models (mamba, LFM2, etc.) 
+ # can use it for conv state cache dtype + self.hf_config.torch_dtype = self.dtype + # Derive context length and model shapes self._derive_context_length(context_length) self._derive_model_shapes() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9fa40b1906aa..3bb42676a872 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -545,9 +545,7 @@ def initialize(self, min_per_gpu_memory: float): self.configure_kv_cache_dtype() # Init memory pool and attention backends - # Set default dtype so mamba2_cache_params picks up the correct dtype for conv state - with set_default_torch_dtype(self.model_config.dtype): - self.init_memory_pool(min_per_gpu_memory) + self.init_memory_pool(min_per_gpu_memory) # Init max running requests self.max_running_requests = min( @@ -1481,6 +1479,12 @@ def hybrid_gdn_config(self): @property def mamba2_config(self): config = self.model_config.hf_config + if isinstance(config, NemotronHConfig) and self.is_draft_worker: + # NemotronH MTP draft models have no Mamba layers (pattern like "*E") + # so they shouldn't use HybridLinearAttnBackend + pattern = getattr(config, "mtp_hybrid_override_pattern", None) + if pattern is not None and "M" not in pattern: + return None if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config): return config if isinstance(config, NemotronH_Nano_VL_V2_Config): From e6d006072169f97d5171e6dcbd5042a041a07f9d Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Sat, 17 Jan 2026 11:16:42 +0000 Subject: [PATCH 12/14] Fix LFM2 on Blackwell (SM100) GPUs LFM2 was failing on B200/SM100 because: 1. SM100 defaults to trtllm_mha backend which forces page_size=64 2. MambaRadixCache requires page_size=1 for hybrid models 3. 
Triton backend doesn't work because LFM2's first layer is conv, not attention Add Lfm2ForCausalLM to server_args.py with same handling as NemotronH: - Use flashinfer backend on SM100 (supports page_size=1) - Disable overlap schedule with radix cache - Block triton backend (layer 0 is not an attention layer) --- python/sglang/srt/server_args.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 78ec3fc71d30..5a8f234a5669 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1552,6 +1552,33 @@ def _handle_model_specific_adjustments(self): ) self.disable_radix_cache = True self.disable_overlap_schedule = False + elif model_arch in ["Lfm2ForCausalLM"]: + assert ( + not self.enable_mamba_extra_buffer() + ), f"mamba extra_buffer is not supported for {model_arch} model" + if not self.disable_radix_cache: + logger.warning( + "Disabling overlap schedule since mamba no_buffer is not compatible with " + "overlap schedule, try to use --disable-radix-cache if overlap schedule is necessary" + ) + self.disable_overlap_schedule = True + if is_sm100_supported(): + if self.attention_backend is None: + self.attention_backend = "flashinfer" + logger.info( + f"Use flashinfer as attention backend on sm100 for {model_arch}" + ) + if self.attention_backend == "trtllm_mha": + logger.warning( + "Disabling radix cache since trtllm_mha does not support page_size = 1, which is required by MambaRadixCache. " + "Try to use --attention-backend flashinfer if radix cache is necessary." 
+ ) + self.disable_radix_cache = True + self.disable_overlap_schedule = False + assert self.attention_backend != "triton", ( + f"{model_arch} does not support triton attention backend, " + "as the first layer might not be an attention layer" + ) if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set(): self.disable_overlap_schedule = True From e159dfe450de2a1dee2863e66f373264f70563a3 Mon Sep 17 00:00:00 2001 From: Piotr Mazurek Date: Wed, 21 Jan 2026 15:34:38 +0000 Subject: [PATCH 13/14] Updated way we set up conv type in Mamba2StateDType --- python/sglang/srt/configs/lfm2.py | 14 +++----------- python/sglang/srt/configs/mamba_utils.py | 8 ++++---- python/sglang/srt/configs/model_config.py | 4 ---- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/configs/lfm2.py b/python/sglang/srt/configs/lfm2.py index 6ce5ad97940d..147feed98d7d 100644 --- a/python/sglang/srt/configs/lfm2.py +++ b/python/sglang/srt/configs/lfm2.py @@ -16,16 +16,11 @@ from typing import List, Optional -import torch from transformers import CONFIG_MAPPING from transformers import Lfm2Config as HFLfm2Config from transformers.utils import logging -from sglang.srt.configs.mamba_utils import ( - Mamba2CacheParams, - Mamba2StateDType, - Mamba2StateShape, -) +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape logger = logging.get_logger(__name__) @@ -92,14 +87,11 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]: conv_kernel=conv_kernel, ) - # Use runtime dtype from model_config (propagated to hf_config.torch_dtype) - # Conv state must match model's inference dtype - conv_dtype = getattr(self, "torch_dtype", torch.bfloat16) - + # Uses default mamba2_state_dtype() which reads SGLANG_MAMBA_CONV_DTYPE env var + # (defaults to bfloat16). Set SGLANG_MAMBA_CONV_DTYPE=float16 for fp16 inference. 
return Mamba2CacheParams( shape=shape, layers=conv_layer_ids, - dtype=Mamba2StateDType(conv=conv_dtype, temporal=torch.float32), ) diff --git a/python/sglang/srt/configs/mamba_utils.py b/python/sglang/srt/configs/mamba_utils.py index 810709fc9596..9e64d752494c 100644 --- a/python/sglang/srt/configs/mamba_utils.py +++ b/python/sglang/srt/configs/mamba_utils.py @@ -41,19 +41,19 @@ class Mamba2StateDType: temporal: torch.dtype -CONV_DTYPE = torch.bfloat16 - - def mamba2_state_dtype() -> Mamba2StateDType: dtype_map = { "float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16, } + conv_dtype = dtype_map.get( + os.environ.get("SGLANG_MAMBA_CONV_DTYPE", "bfloat16"), torch.bfloat16 + ) ssm_dtype = dtype_map.get( os.environ.get("SGLANG_MAMBA_SSM_DTYPE", "float32"), torch.float32 ) - return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype) + return Mamba2StateDType(conv=conv_dtype, temporal=ssm_dtype) @dataclass(kw_only=True, frozen=True) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 458cc8053cff..cd62f3f76d52 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -192,10 +192,6 @@ def __init__( ) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - # Propagate runtime dtype to hf_config so that hybrid models (mamba, LFM2, etc.) 
- # can use it for conv state cache dtype
-        self.hf_config.torch_dtype = self.dtype
-
         # Derive context length and model shapes
         self._derive_context_length(context_length)
         self._derive_model_shapes()


From c5ee6ea7bc466a74f b55305c3c60d8b979fc42e5 Mon Sep 17 00:00:00 2001
From: Piotr Mazurek
Date: Wed, 21 Jan 2026 16:14:57 +0000
Subject: [PATCH 14/14] Set conv dtype variable to match model dtype in tests

---
 test/registered/models/test_generation_models.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/test/registered/models/test_generation_models.py b/test/registered/models/test_generation_models.py
index 7f0293390d1f..a180102ec30d 100644
--- a/test/registered/models/test_generation_models.py
+++ b/test/registered/models/test_generation_models.py
@@ -138,6 +138,12 @@ def assert_close_logits_and_output_strs(
     )
     max_new_tokens = 32
 
+    # Set conv dtype for hybrid models to match inference dtype
+    dtype_str = {torch.float16: "float16", torch.bfloat16: "bfloat16"}.get(
+        torch_dtype, "bfloat16"
+    )
+    os.environ["SGLANG_MAMBA_CONV_DTYPE"] = dtype_str
+
    with HFRunner(
        model_path,
        torch_dtype=torch_dtype,