Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/configs/lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:
return None

hidden_size = self.hidden_size
# conv_L_cache in config is kernel_size (e.g., 3)
conv_kernel = int(self.conv_L_cache)
L_cache = conv_kernel - 1 # actual cache size (e.g., 2 for kernel=3)

# get_attention_tp_size() requires initialization, default to 1 if not available
try:
Expand All @@ -77,11 +75,13 @@ def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:

# For ShortConv layers, we use a simplified Mamba2StateShape
# LFM2 doesn't use SSM state (state_size=0), only conv state
# We pass num_heads=tp_size so divide(tp_size, tp_size)=1 always works.
# Since state_size=0, the temporal state shape has zero elements anyway.
shape = Mamba2StateShape.create(
tp_world_size=tp_size,
intermediate_size=hidden_size,
n_groups=1, # ShortConv doesn't use grouping
num_heads=1, # ShortConv is not multi-head
num_heads=tp_size, # Ensures divide works; temporal state is empty anyway
head_dim=hidden_size, # Conv operates on full hidden dim
state_size=0, # No SSM temporal state for ShortConv
conv_kernel=conv_kernel,
Expand Down
102 changes: 45 additions & 57 deletions python/sglang/srt/models/lfm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
from torch import nn

from sglang.srt.configs.lfm2 import Lfm2Config
from sglang.srt.distributed import get_pp_group
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.attention.mamba.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
Expand All @@ -39,30 +40,15 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
sharded_weight_loader,
)
from sglang.srt.utils import add_prefix, make_layers, set_weight_attrs

logger = logging.getLogger(__name__)


# We don't use it, we keep it for reference. If we run sglang.srt.layers.layernorm.RMSNorm
# kernel the difference in logprobs slightly increases, but to an acceptable degree
# class Lfm2RMSNorm(nn.Module):
# """LFM2-specific RMSNorm: weight * x (not (1 + weight) * x like Gemma)."""

# def __init__(self, hidden_size: int, eps: float = 1e-6):
# super().__init__()
# self.weight = nn.Parameter(torch.ones(hidden_size))
# self.variance_epsilon = eps

# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# input_dtype = hidden_states.dtype
# hidden_states = hidden_states.to(torch.float32)
# variance = hidden_states.pow(2).mean(-1, keepdim=True)
# hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# return (self.weight * hidden_states).to(input_dtype)


class Lfm2MLP(nn.Module):
"""MLP with SwiGLU activation."""

Expand Down Expand Up @@ -122,7 +108,6 @@ def __init__(
self,
config: Lfm2Config,
layer_id: int,
attn_layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
Expand Down Expand Up @@ -221,6 +206,7 @@ class Lfm2ShortConv(nn.Module):
- Uses double gating: B (before conv) and C (after conv)
- Fixed-size cache: stores last (kernel_size - 1) tokens
- Uses causal_conv1d_fn for prefill and causal_conv1d_update for decode
- Supports tensor parallelism: hidden dimension is sharded across TP ranks
"""

def __init__(
Expand All @@ -233,24 +219,39 @@ def __init__(
super().__init__()
self.layer_idx = layer_idx
self.conv_kernel = int(config.conv_L_cache)
self.L_cache = self.conv_kernel - 1
self.use_bias = bool(config.conv_bias)
self.hidden_size = config.hidden_size

self.in_proj = nn.Linear(
config.hidden_size, 3 * config.hidden_size, bias=self.use_bias
tp_size = get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = self.hidden_size // tp_size

# Use MergedColumnParallelLinear so each output (B, C, x) is sharded separately
self.in_proj = MergedColumnParallelLinear(
config.hidden_size,
[config.hidden_size] * 3, # B, C, x each get hidden_size
bias=self.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
self.out_proj = nn.Linear(
config.hidden_size, config.hidden_size, bias=self.use_bias
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.use_bias,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)

# Conv weights stored in format matching causal_conv1d: (hidden_size, kernel_size)
# Weight loading will handle conversion from HF's (hidden_size, 1, kernel_size)
# Conv weights sharded along hidden dimension: (hidden_size/tp, kernel_size)
self.conv_weight = nn.Parameter(
torch.empty(config.hidden_size, self.conv_kernel)
torch.empty(self.hidden_size_per_partition, self.conv_kernel)
)
set_weight_attrs(self.conv_weight, {"weight_loader": sharded_weight_loader(0)})
if self.use_bias:
self.conv_bias = nn.Parameter(torch.empty(config.hidden_size))
self.conv_bias = nn.Parameter(torch.empty(self.hidden_size_per_partition))
set_weight_attrs(
self.conv_bias, {"weight_loader": sharded_weight_loader(0)}
)
else:
self.register_parameter("conv_bias", None)

Expand All @@ -267,7 +268,7 @@ def forward(
req_pool_indices = forward_batch.req_pool_indices

# Project and split into gates: B (pre-conv), C (post-conv), x (input)
proj = self.in_proj(hidden_states)
proj, _ = self.in_proj(hidden_states)
B_gate, C_gate, x = proj.chunk(3, dim=-1)
Bx = B_gate * x

Expand Down Expand Up @@ -315,7 +316,8 @@ def forward(
activation=None,
).transpose(0, 1)

return self.out_proj(C_gate * conv_out)
output, _ = self.out_proj(C_gate * conv_out)
return output


class Lfm2DecoderLayer(nn.Module):
Expand All @@ -325,7 +327,6 @@ def __init__(
self,
config: Lfm2Config,
layer_id: int,
attn_layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
Expand All @@ -340,7 +341,6 @@ def __init__(
self.self_attn = Lfm2Attention(
config=config,
layer_id=layer_id,
attn_layer_id=attn_layer_id,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
Expand Down Expand Up @@ -401,23 +401,15 @@ def __init__(
prefix=add_prefix("embed_tokens", prefix),
)

# Compute attention layer IDs for KV cache
attn_layer_ids = []
attn_count = 0
for layer_type in config.layer_types:
if layer_type == "full_attention":
attn_layer_ids.append(attn_count)
attn_count += 1
else:
attn_layer_ids.append(-1)

self.num_attention_layers = attn_count
# Count attention layers for KV cache sizing
self.num_attention_layers = sum(
1 for lt in config.layer_types if lt == "full_attention"
)

def get_layer(idx: int, prefix: str, **kwargs):
return Lfm2DecoderLayer(
config=config,
layer_id=idx,
attn_layer_id=attn_layer_ids[idx],
quant_config=quant_config,
prefix=prefix,
)
Expand Down Expand Up @@ -516,16 +508,12 @@ def load_weights(
if "embed_tokens.weight" in name:
embed_tokens_weight = loaded_weight

# Handle conv.weight -> conv_weight conversion for ShortConv layers
# HF shape: (hidden_size, 1, kernel_size) -> squeeze to (hidden_size, kernel_size)
if ".conv.weight" in name:
name = name.replace(".conv.weight", ".conv_weight")
# Squeeze out the middle dimension: (D, 1, K) -> (D, K)
loaded_weight = loaded_weight.squeeze(1)

# Handle conv.bias -> conv_bias conversion
if ".conv.bias" in name:
name = name.replace(".conv.bias", ".conv_bias")
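# Rename ShortConv conv parameters to match this model's parameter names:
# HF checkpoint keys end in `.conv.conv.weight` / `.conv.conv.bias`, while
# Lfm2ShortConv registers them directly as `conv_weight` / `conv_bias`,
# so the keys are rewritten to `.conv.conv_weight` / `.conv.conv_bias`.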
if ".conv.conv.weight" in name:
name = name.replace(".conv.conv.weight", ".conv.conv_weight")
loaded_weight = loaded_weight.squeeze(1) # (D, 1, K) -> (D, K)
if ".conv.conv.bias" in name:
name = name.replace(".conv.conv.bias", ".conv.conv_bias")
Comment on lines +512 to +516
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The weight-name patterns matched here changed from `.conv.weight` to `.conv.conv.weight` and from `.conv.bias` to `.conv.conv.bias`, which suggests a discrepancy between the SGLang model's internal naming convention and the parameter names in the HuggingFace checkpoint. This fix addresses the loading issue, but it would be beneficial to add a comment explaining this specific naming adaptation, especially if it is a common pattern for LFM2 models or a known quirk of the upstream checkpoint. Doing so will make the code clearer for future maintainers.


# Handle QKV stacking
for param_name, weight_name, shard_id in stacked_params_mapping:
Expand Down
Loading