Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions python/sglang/srt/configs/linear_attn_model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Registry for linear attention hybrid models (softmax + linear attention).

External models can register themselves without modifying SGLang core files:

from sglang.srt.configs.linear_attn_model_registry import (
register_linear_attn_model, LinearAttnModelSpec,
)

register_linear_attn_model(LinearAttnModelSpec(
config_class=MyLinearAttnConfig,
backend_class_name="sglang.srt.layers.attention.linear.kda_backend.KDAAttnBackend",
arch_names=["MyLinearAttnForCausalLM"],
uses_mamba_radix_cache=True,
support_mamba_cache=True,
))
"""

from __future__ import annotations

import importlib
import logging
from dataclasses import dataclass, field
from typing import Any, Optional

logger = logging.getLogger(__name__)


@dataclass
class LinearAttnModelSpec:
"""Specification for a hybrid (softmax + linear attention) model."""

config_class: type
backend_class_name: str # fully-qualified class name, lazily imported
arch_names: list[str] = field(default_factory=list)
uses_mamba_radix_cache: bool = True
support_mamba_cache: bool = True
support_mamba_cache_extra_buffer: bool = False
unwrap_text_config: bool = False # call get_text_config() before isinstance check


_LINEAR_ATTN_MODEL_REGISTRY: list[LinearAttnModelSpec] = []


def register_linear_attn_model(spec: LinearAttnModelSpec) -> None:
_LINEAR_ATTN_MODEL_REGISTRY.append(spec)
logger.info(
"Registered linear attn model: config=%s, backend=%s, archs=%s",
spec.config_class.__name__,
spec.backend_class_name.rsplit(".", 1)[-1],
spec.arch_names,
)


def get_linear_attn_config(hf_config: Any) -> Optional[tuple[LinearAttnModelSpec, Any]]:
for spec in _LINEAR_ATTN_MODEL_REGISTRY:
config = hf_config.get_text_config() if spec.unwrap_text_config else hf_config
if isinstance(config, spec.config_class):
return spec, config
return None


def get_linear_attn_spec_by_arch(arch_name: str) -> Optional[LinearAttnModelSpec]:
for spec in _LINEAR_ATTN_MODEL_REGISTRY:
if arch_name in spec.arch_names:
return spec
return None


def import_backend_class(dotted_name: str) -> type:
module_path, class_name = dotted_name.rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)
19 changes: 16 additions & 3 deletions python/sglang/srt/layers/attention/attention_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import logging
from typing import TYPE_CHECKING

from sglang.srt.configs.linear_attn_model_registry import (
get_linear_attn_config,
import_backend_class,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -225,9 +230,17 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
elif runner.hybrid_lightning_config is not None:
linear_attn_backend = LightningAttentionBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
)
spec_result = get_linear_attn_config(runner.model_config.hf_config)
if spec_result is not None:
spec, _ = spec_result
BackendClass = import_backend_class(spec.backend_class_name)
linear_attn_backend = BackendClass(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model. "
"If this is a custom hybrid model, use register_linear_attn_model() "
"from sglang.srt.configs.linear_attn_model_registry."
)
full_attn_layers = cfg.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
if (
model_runner.hybrid_gdn_config is not None
or model_runner.kimi_linear_config is not None
or model_runner.linear_attn_model_spec is not None
):
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,14 @@ def init_cache_with_memory_pool(self):

# Hybrid memory pool
self.is_hybrid_swa = self.tp_worker.is_hybrid_swa
_spec = self.tp_worker.model_runner.linear_attn_model_spec
_registry_needs_mamba = (
_spec.uses_mamba_radix_cache if _spec is not None else False
)
self.is_hybrid_ssm = (
self.tp_worker.model_runner.hybrid_gdn_config is not None
or self.tp_worker.model_runner.mamba2_config is not None
or _registry_needs_mamba
)

self.sliding_window_size = None
Expand Down
19 changes: 18 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
Qwen3NextConfig,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_config
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import AttentionArch, ModelConfig, ModelImpl
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
Expand Down Expand Up @@ -1890,14 +1891,30 @@ def kimi_linear_config(self):
return config
return None

def _get_linear_attn_registry_result(self):
if not hasattr(self, "_linear_attn_registry_cache"):
self._linear_attn_registry_cache = get_linear_attn_config(
self.model_config.hf_config
)
return self._linear_attn_registry_cache

@property
def linear_attn_model_spec(self):
result = self._get_linear_attn_registry_result()
return result[0] if result else None

@property
def mambaish_config(self):
return (
existing = (
self.mamba2_config
or self.hybrid_gdn_config
or self.kimi_linear_config
or self.hybrid_lightning_config
)
if existing:
return existing
result = self._get_linear_attn_registry_result()
return result[1] if result else None

def configure_kv_cache_dtype(self):
if self.server_args.kv_cache_dtype == "auto":
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import tempfile
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_spec_by_arch
from sglang.srt.connector import ConnectorType
from sglang.srt.environ import envs
from sglang.srt.function_call.function_call_parser import FunctionCallParser
Expand Down Expand Up @@ -1501,6 +1502,14 @@ def _handle_model_specific_adjustments(self):
hf_config = self.get_model_config().hf_config
model_arch = hf_config.architectures[0]

_hybrid_spec = get_linear_attn_spec_by_arch(model_arch)
if _hybrid_spec is not None:
self._handle_mamba_radix_cache(
model_arch=model_arch,
support_mamba_cache=_hybrid_spec.support_mamba_cache,
support_mamba_cache_extra_buffer=_hybrid_spec.support_mamba_cache_extra_buffer,
)

if model_arch in [
"MistralLarge3ForCausalLM",
"PixtralForConditionalGeneration",
Expand Down
161 changes: 161 additions & 0 deletions test/registered/unit/configs/test_linear_attn_model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Unit tests for srt/configs/linear_attn_model_registry.py"""

import unittest

from sglang.srt.configs.linear_attn_model_registry import (
_LINEAR_ATTN_MODEL_REGISTRY,
LinearAttnModelSpec,
get_linear_attn_config,
get_linear_attn_spec_by_arch,
import_backend_class,
register_linear_attn_model,
)
from sglang.test.ci.ci_register import register_cpu_ci
from sglang.test.test_utils import CustomTestCase

register_cpu_ci(est_time=5, suite="stage-a-test-cpu")


# Dummy config classes for testing
class FakeLinearAttnConfig:
full_attention_layer_ids = [0, 2, 4]


class FakeVLMWrapperConfig:
"""Simulates a VLM wrapper that has get_text_config()."""

def __init__(self):
self._text_config = FakeLinearAttnConfig()

def get_text_config(self):
return self._text_config


class AnotherConfig:
pass


class TestLinearAttnModelRegistry(CustomTestCase):
def setUp(self):
# Save and clear the global registry between tests
self._saved_registry = list(_LINEAR_ATTN_MODEL_REGISTRY)
_LINEAR_ATTN_MODEL_REGISTRY.clear()

def tearDown(self):
_LINEAR_ATTN_MODEL_REGISTRY.clear()
_LINEAR_ATTN_MODEL_REGISTRY.extend(self._saved_registry)

def _make_spec(self, **overrides):
defaults = dict(
config_class=FakeLinearAttnConfig,
backend_class_name="sglang.srt.layers.attention.triton_backend.TritonAttnBackend",
arch_names=["FakeModelForCausalLM"],
)
defaults.update(overrides)
return LinearAttnModelSpec(**defaults)

def test_register_and_lookup_by_config(self):
spec = self._make_spec()
register_linear_attn_model(spec)

hf_config = FakeLinearAttnConfig()
result = get_linear_attn_config(hf_config)
self.assertIsNotNone(result)
self.assertIs(result[0], spec)
self.assertIs(result[1], hf_config)

def test_lookup_no_match(self):
spec = self._make_spec()
register_linear_attn_model(spec)

result = get_linear_attn_config(AnotherConfig())
self.assertIsNone(result)

def test_lookup_empty_registry(self):
result = get_linear_attn_config(FakeLinearAttnConfig())
self.assertIsNone(result)

def test_unwrap_text_config(self):
spec = self._make_spec(unwrap_text_config=True)
register_linear_attn_model(spec)

vlm_config = FakeVLMWrapperConfig()
result = get_linear_attn_config(vlm_config)
self.assertIsNotNone(result)
self.assertIs(result[0], spec)
# The resolved config should be the inner text config
self.assertIsInstance(result[1], FakeLinearAttnConfig)
self.assertIs(result[1], vlm_config._text_config)

def test_unwrap_text_config_no_match(self):
"""unwrap_text_config=False should not call get_text_config()."""
spec = self._make_spec(unwrap_text_config=False)
register_linear_attn_model(spec)

vlm_config = FakeVLMWrapperConfig()
# VLM wrapper itself is not a FakeLinearAttnConfig, so no match
result = get_linear_attn_config(vlm_config)
self.assertIsNone(result)

def test_lookup_by_arch(self):
spec = self._make_spec(arch_names=["AlphaForCausalLM", "BetaForCausalLM"])
register_linear_attn_model(spec)

self.assertIs(get_linear_attn_spec_by_arch("AlphaForCausalLM"), spec)
self.assertIs(get_linear_attn_spec_by_arch("BetaForCausalLM"), spec)
self.assertIsNone(get_linear_attn_spec_by_arch("GammaForCausalLM"))

def test_lookup_by_arch_empty_registry(self):
self.assertIsNone(get_linear_attn_spec_by_arch("AnyArch"))

def test_multiple_registrations(self):
spec_a = self._make_spec(
config_class=FakeLinearAttnConfig,
arch_names=["AlphaForCausalLM"],
)
spec_b = self._make_spec(
config_class=AnotherConfig,
arch_names=["BetaForCausalLM"],
)
register_linear_attn_model(spec_a)
register_linear_attn_model(spec_b)

# Config-based lookup
self.assertIs(get_linear_attn_config(FakeLinearAttnConfig())[0], spec_a)
self.assertIs(get_linear_attn_config(AnotherConfig())[0], spec_b)

# Arch-based lookup
self.assertIs(get_linear_attn_spec_by_arch("AlphaForCausalLM"), spec_a)
self.assertIs(get_linear_attn_spec_by_arch("BetaForCausalLM"), spec_b)

def test_first_match_wins(self):
"""When two specs match the same config class, the first registered wins."""
spec1 = self._make_spec(backend_class_name="pkg.Backend1")
spec2 = self._make_spec(backend_class_name="pkg.Backend2")
register_linear_attn_model(spec1)
register_linear_attn_model(spec2)

result = get_linear_attn_config(FakeLinearAttnConfig())
self.assertIs(result[0], spec1)

def test_import_backend_class(self):
# Import a real stdlib class to verify the mechanism
cls = import_backend_class("collections.OrderedDict")
from collections import OrderedDict

self.assertIs(cls, OrderedDict)

def test_spec_defaults(self):
spec = LinearAttnModelSpec(
config_class=FakeLinearAttnConfig,
backend_class_name="pkg.mod.Cls",
)
self.assertEqual(spec.arch_names, [])
self.assertTrue(spec.uses_mamba_radix_cache)
self.assertTrue(spec.support_mamba_cache)
self.assertFalse(spec.support_mamba_cache_extra_buffer)
self.assertFalse(spec.unwrap_text_config)


if __name__ == "__main__":
unittest.main()
Loading