diff --git a/python/sglang/srt/configs/linear_attn_model_registry.py b/python/sglang/srt/configs/linear_attn_model_registry.py
new file mode 100644
index 000000000000..33fdae8f0783
--- /dev/null
+++ b/python/sglang/srt/configs/linear_attn_model_registry.py
@@ -0,0 +1,93 @@
+"""Registry for linear attention hybrid models (softmax + linear attention).
+
+External models can register themselves without modifying SGLang core files:
+
+    from sglang.srt.configs.linear_attn_model_registry import (
+        register_linear_attn_model, LinearAttnModelSpec,
+    )
+
+    register_linear_attn_model(LinearAttnModelSpec(
+        config_class=MyLinearAttnConfig,
+        backend_class_name="sglang.srt.layers.attention.linear.kda_backend.KDAAttnBackend",
+        arch_names=["MyLinearAttnForCausalLM"],
+        uses_mamba_radix_cache=True,
+        support_mamba_cache=True,
+    ))
+"""
+
+from __future__ import annotations
+
+import importlib
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LinearAttnModelSpec:
+    """Specification for a hybrid (softmax + linear attention) model."""
+
+    config_class: type
+    backend_class_name: str  # fully-qualified class name, lazily imported
+    arch_names: list[str] = field(default_factory=list)
+    uses_mamba_radix_cache: bool = True
+    support_mamba_cache: bool = True
+    support_mamba_cache_extra_buffer: bool = False
+    unwrap_text_config: bool = False  # call get_text_config() before isinstance check
+
+
+_LINEAR_ATTN_MODEL_REGISTRY: list[LinearAttnModelSpec] = []
+
+
+def register_linear_attn_model(spec: LinearAttnModelSpec) -> None:
+    """Append ``spec`` to the registry; earlier registrations win on lookup."""
+    _LINEAR_ATTN_MODEL_REGISTRY.append(spec)
+    logger.info(
+        "Registered linear attn model: config=%s, backend=%s, archs=%s",
+        spec.config_class.__name__,
+        spec.backend_class_name.rsplit(".", 1)[-1],
+        spec.arch_names,
+    )
+
+
+def get_linear_attn_config(hf_config: Any) -> Optional[tuple[LinearAttnModelSpec, Any]]:
+    """Return ``(spec, resolved_config)`` for the first matching spec, else None."""
+    for spec in _LINEAR_ATTN_MODEL_REGISTRY:
+        config = hf_config
+        if spec.unwrap_text_config:
+            # A config without get_text_config() must not abort the scan for
+            # specs registered later; match against the config itself instead.
+            unwrap = getattr(hf_config, "get_text_config", None)
+            if callable(unwrap):
+                config = unwrap()
+        if isinstance(config, spec.config_class):
+            return spec, config
+    return None
+
+
+def get_linear_attn_spec_by_arch(arch_name: str) -> Optional[LinearAttnModelSpec]:
+    """Return the first registered spec listing ``arch_name``, else None."""
+    for spec in _LINEAR_ATTN_MODEL_REGISTRY:
+        if arch_name in spec.arch_names:
+            return spec
+    return None
+
+
+def import_backend_class(dotted_name: str) -> type:
+    """Import and return the class named by a ``package.module.Class`` path.
+
+    Raises:
+        ValueError: ``dotted_name`` has no module part or no class part.
+        ImportError: the module cannot be imported.
+        AttributeError: the module does not define the class.
+    """
+    module_path, _, class_name = dotted_name.rpartition(".")
+    if not module_path or not class_name:
+        raise ValueError(
+            "backend_class_name must be a fully-qualified 'module.Class' path, "
+            f"got {dotted_name!r}"
+        )
+    module = importlib.import_module(module_path)
+    return getattr(module, class_name)
diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py
index 2353c15993fd..0a5920575c0f 100644
--- a/python/sglang/srt/layers/attention/attention_registry.py
+++ b/python/sglang/srt/layers/attention/attention_registry.py
@@ -1,6 +1,11 @@
 import logging
 from typing import TYPE_CHECKING
 
+from sglang.srt.configs.linear_attn_model_registry import (
+    get_linear_attn_config,
+    import_backend_class,
+)
+
 logger = logging.getLogger(__name__)
 
 
@@ -225,9 +230,17 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
         elif runner.hybrid_lightning_config is not None:
             linear_attn_backend = LightningAttentionBackend(runner)
         else:
-            raise ValueError(
-                "Expected hybrid GDN or NemotronH models, but got unknown model."
-            )
+            spec_result = get_linear_attn_config(runner.model_config.hf_config)
+            if spec_result is not None:
+                spec, _ = spec_result
+                BackendClass = import_backend_class(spec.backend_class_name)
+                linear_attn_backend = BackendClass(runner)
+            else:
+                raise ValueError(
+                    "Expected hybrid GDN or NemotronH models, but got unknown model. "
+                    "If this is a custom hybrid model, use register_linear_attn_model() "
+                    "from sglang.srt.configs.linear_attn_model_registry."
+                )
         full_attn_layers = cfg.full_attention_layer_ids
         return HybridLinearAttnBackend(
             full_attn_backend, linear_attn_backend, full_attn_layers
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 18ed55572cfe..e5705a03781a 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -97,6 +97,7 @@ def __init__(
         if (
             model_runner.hybrid_gdn_config is not None
             or model_runner.kimi_linear_config is not None
+            or model_runner.linear_attn_model_spec is not None
         ):
             # For hybrid linear models, layer_id = 0 may not be full attention
             self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 060fdbe4be7d..5a8779f30f30 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -706,9 +706,14 @@ def init_cache_with_memory_pool(self):
 
         # Hybrid memory pool
         self.is_hybrid_swa = self.tp_worker.is_hybrid_swa
+        _spec = self.tp_worker.model_runner.linear_attn_model_spec
+        _registry_needs_mamba = (
+            _spec.uses_mamba_radix_cache if _spec is not None else False
+        )
         self.is_hybrid_ssm = (
             self.tp_worker.model_runner.hybrid_gdn_config is not None
             or self.tp_worker.model_runner.mamba2_config is not None
+            or _registry_needs_mamba
         )
 
         self.sliding_window_size = None
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index a59742b94354..669cab133c49 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -50,6 +50,7 @@
     Qwen3NextConfig,
 )
 from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_config
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig, ModelImpl
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
@@ -1890,14 +1891,30 @@ def kimi_linear_config(self):
             return config
         return None
 
+    def _get_linear_attn_registry_result(self):
+        if not hasattr(self, "_linear_attn_registry_cache"):
+            self._linear_attn_registry_cache = get_linear_attn_config(
+                self.model_config.hf_config
+            )
+        return self._linear_attn_registry_cache
+
+    @property
+    def linear_attn_model_spec(self):
+        result = self._get_linear_attn_registry_result()
+        return result[0] if result else None
+
     @property
     def mambaish_config(self):
-        return (
+        existing = (
             self.mamba2_config
             or self.hybrid_gdn_config
             or self.kimi_linear_config
             or self.hybrid_lightning_config
         )
+        if existing:
+            return existing
+        result = self._get_linear_attn_registry_result()
+        return result[1] if result else None
 
     def configure_kv_cache_dtype(self):
         if self.server_args.kv_cache_dtype == "auto":
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 57dbc1998b90..b46d83d70c13 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -26,6 +26,7 @@
 import tempfile
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
+from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_spec_by_arch
 from sglang.srt.connector import ConnectorType
 from sglang.srt.environ import envs
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -1501,6 +1502,14 @@ def _handle_model_specific_adjustments(self):
         hf_config = self.get_model_config().hf_config
         model_arch = hf_config.architectures[0]
 
+        _hybrid_spec = get_linear_attn_spec_by_arch(model_arch)
+        if _hybrid_spec is not None:
+            self._handle_mamba_radix_cache(
+                model_arch=model_arch,
+                support_mamba_cache=_hybrid_spec.support_mamba_cache,
+                support_mamba_cache_extra_buffer=_hybrid_spec.support_mamba_cache_extra_buffer,
+            )
+
         if model_arch in [
             "MistralLarge3ForCausalLM",
             "PixtralForConditionalGeneration",
diff --git a/test/registered/unit/configs/test_linear_attn_model_registry.py b/test/registered/unit/configs/test_linear_attn_model_registry.py
new file mode 100644
index 000000000000..6fbd27d0ae7e
--- /dev/null
+++ b/test/registered/unit/configs/test_linear_attn_model_registry.py
@@ -0,0 +1,176 @@
+"""Unit tests for srt/configs/linear_attn_model_registry.py"""
+
+import unittest
+
+from sglang.srt.configs.linear_attn_model_registry import (
+    _LINEAR_ATTN_MODEL_REGISTRY,
+    LinearAttnModelSpec,
+    get_linear_attn_config,
+    get_linear_attn_spec_by_arch,
+    import_backend_class,
+    register_linear_attn_model,
+)
+from sglang.test.ci.ci_register import register_cpu_ci
+from sglang.test.test_utils import CustomTestCase
+
+register_cpu_ci(est_time=5, suite="stage-a-test-cpu")
+
+
+# Dummy config classes for testing
+class FakeLinearAttnConfig:
+    full_attention_layer_ids = [0, 2, 4]
+
+
+class FakeVLMWrapperConfig:
+    """Simulates a VLM wrapper that has get_text_config()."""
+
+    def __init__(self):
+        self._text_config = FakeLinearAttnConfig()
+
+    def get_text_config(self):
+        return self._text_config
+
+
+class AnotherConfig:
+    pass
+
+
+class TestLinearAttnModelRegistry(CustomTestCase):
+    def setUp(self):
+        # Save and clear the global registry between tests
+        self._saved_registry = list(_LINEAR_ATTN_MODEL_REGISTRY)
+        _LINEAR_ATTN_MODEL_REGISTRY.clear()
+
+    def tearDown(self):
+        _LINEAR_ATTN_MODEL_REGISTRY.clear()
+        _LINEAR_ATTN_MODEL_REGISTRY.extend(self._saved_registry)
+
+    def _make_spec(self, **overrides):
+        defaults = dict(
+            config_class=FakeLinearAttnConfig,
+            backend_class_name="sglang.srt.layers.attention.triton_backend.TritonAttnBackend",
+            arch_names=["FakeModelForCausalLM"],
+        )
+        defaults.update(overrides)
+        return LinearAttnModelSpec(**defaults)
+
+    def test_register_and_lookup_by_config(self):
+        spec = self._make_spec()
+        register_linear_attn_model(spec)
+
+        hf_config = FakeLinearAttnConfig()
+        result = get_linear_attn_config(hf_config)
+        self.assertIsNotNone(result)
+        self.assertIs(result[0], spec)
+        self.assertIs(result[1], hf_config)
+
+    def test_lookup_no_match(self):
+        spec = self._make_spec()
+        register_linear_attn_model(spec)
+
+        result = get_linear_attn_config(AnotherConfig())
+        self.assertIsNone(result)
+
+    def test_lookup_empty_registry(self):
+        result = get_linear_attn_config(FakeLinearAttnConfig())
+        self.assertIsNone(result)
+
+    def test_unwrap_text_config(self):
+        spec = self._make_spec(unwrap_text_config=True)
+        register_linear_attn_model(spec)
+
+        vlm_config = FakeVLMWrapperConfig()
+        result = get_linear_attn_config(vlm_config)
+        self.assertIsNotNone(result)
+        self.assertIs(result[0], spec)
+        # The resolved config should be the inner text config
+        self.assertIsInstance(result[1], FakeLinearAttnConfig)
+        self.assertIs(result[1], vlm_config._text_config)
+
+    def test_unwrap_text_config_no_match(self):
+        """unwrap_text_config=False should not call get_text_config()."""
+        spec = self._make_spec(unwrap_text_config=False)
+        register_linear_attn_model(spec)
+
+        vlm_config = FakeVLMWrapperConfig()
+        # VLM wrapper itself is not a FakeLinearAttnConfig, so no match
+        result = get_linear_attn_config(vlm_config)
+        self.assertIsNone(result)
+
+    def test_unwrap_text_config_without_method(self):
+        """A config lacking get_text_config() is matched as-is, not an error."""
+        spec = self._make_spec(unwrap_text_config=True)
+        register_linear_attn_model(spec)
+
+        hf_config = FakeLinearAttnConfig()
+        result = get_linear_attn_config(hf_config)
+        self.assertIsNotNone(result)
+        self.assertIs(result[1], hf_config)
+
+    def test_lookup_by_arch(self):
+        spec = self._make_spec(arch_names=["AlphaForCausalLM", "BetaForCausalLM"])
+        register_linear_attn_model(spec)
+
+        self.assertIs(get_linear_attn_spec_by_arch("AlphaForCausalLM"), spec)
+        self.assertIs(get_linear_attn_spec_by_arch("BetaForCausalLM"), spec)
+        self.assertIsNone(get_linear_attn_spec_by_arch("GammaForCausalLM"))
+
+    def test_lookup_by_arch_empty_registry(self):
+        self.assertIsNone(get_linear_attn_spec_by_arch("AnyArch"))
+
+    def test_multiple_registrations(self):
+        spec_a = self._make_spec(
+            config_class=FakeLinearAttnConfig,
+            arch_names=["AlphaForCausalLM"],
+        )
+        spec_b = self._make_spec(
+            config_class=AnotherConfig,
+            arch_names=["BetaForCausalLM"],
+        )
+        register_linear_attn_model(spec_a)
+        register_linear_attn_model(spec_b)
+
+        # Config-based lookup
+        self.assertIs(get_linear_attn_config(FakeLinearAttnConfig())[0], spec_a)
+        self.assertIs(get_linear_attn_config(AnotherConfig())[0], spec_b)
+
+        # Arch-based lookup
+        self.assertIs(get_linear_attn_spec_by_arch("AlphaForCausalLM"), spec_a)
+        self.assertIs(get_linear_attn_spec_by_arch("BetaForCausalLM"), spec_b)
+
+    def test_first_match_wins(self):
+        """When two specs match the same config class, the first registered wins."""
+        spec1 = self._make_spec(backend_class_name="pkg.Backend1")
+        spec2 = self._make_spec(backend_class_name="pkg.Backend2")
+        register_linear_attn_model(spec1)
+        register_linear_attn_model(spec2)
+
+        result = get_linear_attn_config(FakeLinearAttnConfig())
+        self.assertIs(result[0], spec1)
+
+    def test_import_backend_class(self):
+        # Import a real stdlib class to verify the mechanism
+        cls = import_backend_class("collections.OrderedDict")
+        from collections import OrderedDict
+
+        self.assertIs(cls, OrderedDict)
+
+    def test_import_backend_class_requires_dotted_path(self):
+        for bad_name in ("OrderedDict", ".OrderedDict", "collections."):
+            with self.assertRaises(ValueError):
+                import_backend_class(bad_name)
+
+    def test_spec_defaults(self):
+        spec = LinearAttnModelSpec(
+            config_class=FakeLinearAttnConfig,
+            backend_class_name="pkg.mod.Cls",
+        )
+        self.assertEqual(spec.arch_names, [])
+        self.assertTrue(spec.uses_mamba_radix_cache)
+        self.assertTrue(spec.support_mamba_cache)
+        self.assertFalse(spec.support_mamba_cache_extra_buffer)
+        self.assertFalse(spec.unwrap_text_config)
+
+
+if __name__ == "__main__":
+    unittest.main()