Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/srt/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.lfm2 import Lfm2Config
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config
from sglang.srt.configs.nemotron_h import NemotronHConfig
Expand Down Expand Up @@ -42,6 +43,7 @@
"DotsVLMConfig",
"DotsOCRConfig",
"FalconH1Config",
"Lfm2Config",
"NemotronHConfig",
"NemotronH_Nano_VL_V2_Config",
"JetNemotronConfig",
Expand Down
102 changes: 102 additions & 0 deletions python/sglang/srt/configs/lfm2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2024 Liquid AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LFM2 (Liquid Foundation Model 2) configuration"""

from typing import List, Optional

from transformers import CONFIG_MAPPING
from transformers import Lfm2Config as HFLfm2Config
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape

logger = logging.get_logger(__name__)


class Lfm2Config(HFLfm2Config):
    """
    SGLang configuration for LFM2 (Liquid Foundation Model 2) models.

    Extends HuggingFace's Lfm2Config with the hybrid-model properties that
    SGLang's memory pools expect. LFM2 is a hybrid architecture mixing
    full-attention layers with ShortConv layers.
    """

    @property
    def full_attention_layer_ids(self) -> List[int]:
        """Indices of full-attention layers (these need a KV cache)."""
        return [i for i, lt in enumerate(self.layer_types) if lt == "full_attention"]

    @property
    def linear_layer_ids(self) -> List[int]:
        """Indices of conv layers (these need a conv-state cache)."""
        return [
            i for i, lt in enumerate(self.layer_types) if lt in ("conv", "short_conv")
        ]

    @property
    def mamba_chunk_size(self) -> int:
        """Chunk size for the Mamba2 backend. LFM2 doesn't use chunking, return 1."""
        return 1

    @property
    def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:
        """
        Get cache params for HybridReqToTokenPool initialization.

        LFM2 uses ShortConv layers with a small fixed-size cache
        (``conv_L_cache - 1`` positions). Unlike full Mamba2 models, LFM2 only
        uses the conv state, not the SSM temporal state.

        Returns:
            ``Mamba2CacheParams`` describing the conv-state cache, or ``None``
            when the model has no conv layers at all.
        """
        # Imported lazily so importing this config module does not require the
        # attention/TP machinery to be initialized.
        from sglang.srt.layers.dp_attention import get_attention_tp_size

        conv_layer_ids = self.linear_layer_ids
        if not conv_layer_ids:
            return None

        hidden_size = self.hidden_size
        # conv_L_cache in the HF config is the kernel size (e.g. 3); the state
        # actually cached per layer is kernel_size - 1 positions (e.g. 2).
        conv_kernel = int(self.conv_L_cache)

        # get_attention_tp_size() requires TP initialization; default to 1
        # when it is not available (e.g. config-only usage).
        try:
            tp_size = get_attention_tp_size()
        except (AssertionError, RuntimeError):
            tp_size = 1

        # For ShortConv layers we use a simplified Mamba2StateShape:
        # LFM2 doesn't use SSM state (state_size=0), only conv state.
        shape = Mamba2StateShape.create(
            tp_world_size=tp_size,
            intermediate_size=hidden_size,
            n_groups=1,  # ShortConv doesn't use grouping
            num_heads=1,  # ShortConv is not multi-head
            head_dim=hidden_size,  # conv operates on the full hidden dim
            state_size=0,  # no SSM temporal state for ShortConv
            conv_kernel=conv_kernel,
        )

        # Uses default mamba2_state_dtype() which reads the
        # SGLANG_MAMBA_CONV_DTYPE env var (defaults to bfloat16). Set
        # SGLANG_MAMBA_CONV_DTYPE=float16 for fp16 inference.
        return Mamba2CacheParams(
            shape=shape,
            layers=conv_layer_ids,
        )


# Override HuggingFace's Lfm2Config with our extended version.
# Cannot use CONFIG_MAPPING.register() because "lfm2" is already registered by
# transformers; write into the internal _extra_content dict instead.
# NOTE(review): _extra_content is a private transformers attribute — this may
# break on a transformers upgrade; verify when bumping the dependency.
CONFIG_MAPPING._extra_content["lfm2"] = Lfm2Config
logger.info("Registered SGLang Lfm2Config to override HuggingFace's version")
15 changes: 9 additions & 6 deletions python/sglang/srt/configs/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, LFM2, etc."""

import os
from abc import ABC
Expand Down Expand Up @@ -41,16 +41,19 @@ class Mamba2StateDType:
temporal: torch.dtype


CONV_DTYPE = torch.bfloat16


def mamba2_state_dtype() -> Mamba2StateDType:
    """Resolve the dtypes used for the Mamba2 conv and SSM temporal states.

    The conv-state dtype is read from the ``SGLANG_MAMBA_CONV_DTYPE`` env var
    (default ``bfloat16``); the SSM temporal-state dtype from
    ``SGLANG_MAMBA_SSM_DTYPE`` (default ``float32``).

    Returns:
        Mamba2StateDType with the resolved ``conv`` and ``temporal`` dtypes.

    Raises:
        ValueError: if either env var is set to an unsupported dtype name.
            Failing loudly on a typo (e.g. "bf16") is preferable to silently
            falling back to the default dtype.
    """
    dtype_map = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }

    def _resolve(env_var: str, default: str) -> torch.dtype:
        # Validate explicitly instead of dict.get(..., fallback) so that a
        # misspelled env value raises rather than masking the user's intent.
        name = os.environ.get(env_var, default)
        if name not in dtype_map:
            raise ValueError(
                f"{env_var}={name!r} is not supported; "
                f"choose one of {sorted(dtype_map)}"
            )
        return dtype_map[name]

    conv_dtype = _resolve("SGLANG_MAMBA_CONV_DTYPE", "bfloat16")
    ssm_dtype = _resolve("SGLANG_MAMBA_SSM_DTYPE", "float32")
    return Mamba2StateDType(conv=conv_dtype, temporal=ssm_dtype)


@dataclass(kw_only=True, frozen=True)
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/function_call/function_call_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
from sglang.srt.function_call.internlm_detector import InternlmDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.lfm2_detector import Lfm2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mimo_detector import MiMoDetector
from sglang.srt.function_call.minimax_m2 import MinimaxM2Detector
Expand Down Expand Up @@ -50,6 +51,7 @@ class FunctionCallParser:
"glm47": Glm47MoeDetector,
"gpt-oss": GptOssDetector,
"kimi_k2": KimiK2Detector,
"lfm2": Lfm2Detector,
"llama3": Llama32Detector,
"mimo": MiMoDetector,
"mistral": MistralDetector,
Expand Down
Loading
Loading