Merged

22 commits
c64da9e
feat: Add plugin system support for SGLang
Baidu-AIAK Mar 25, 2026
40b9d26
update: Improve common.py modifications for plugin system
Baidu-AIAK Mar 25, 2026
c7bb8ab
docs: Add plugin documentation and update implementation
Baidu-AIAK Mar 26, 2026
b75ad74
docs: Update plugin documentation with additional details
Baidu-AIAK Mar 26, 2026
da9afbd
feat: Add plugin system support for SGLang
Mar 30, 2026
2132d76
refactor: Merge ClassReplacer into HookRegistry, rename sglang_hook t…
Baidu-AIAK Mar 30, 2026
c3ffa19
Replace custom resolve_obj() with stdlib pkgutil.resolve_name() && ad…
Mar 31, 2026
4b25ba4
DeviceMixin:add 12 new methods with [Active]/[Planned] annotations &&…
Apr 1, 2026
a850974
fix: Delete accidental return statement in assert_pkg_version
Apr 2, 2026
0730a45
fix: Apply REPLACE hooks before AROUND/BEFORE/AFTER regardless of reg…
Apr 6, 2026
34130e4
fix: Propagate hook patches to modules with stale bindings
Apr 9, 2026
0ba5f75
Merge branch 'main' into plugin
alexnails Apr 13, 2026
70adaed
lint
alexnails Apr 13, 2026
01650d9
Move inline imports to top level, add plugin unit tests, fix plugin.m…
Apr 13, 2026
6c89c24
Fix pre-commit: isort import order in engine.py/scheduler.py, black f…
Apr 13, 2026
3954351
Rewrite platform & plugin unit tests per skill guidelines
Apr 14, 2026
42f4542
Fix pre-commit
Apr 14, 2026
1c05d10
Merge branch 'main' into plugin
alexnails Apr 14, 2026
4679dc1
Merge upstream main into plugin, resolve conflicts in scheduler.py an…
Apr 16, 2026
908a221
Fix UnboundLocalError of current_platform in HybridLinearKVPool
Apr 16, 2026
58e4c07
Merge upstream main into plugin, resolve conflicts in server_args.py
Apr 17, 2026
9f853a8
Merge branch 'main' into plugin
mickqian Apr 18, 2026
414 changes: 414 additions & 0 deletions docs/platforms/plugin.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions python/sglang/cli/serve.py
@@ -86,6 +86,10 @@ def serve(args, extra_argv):
)
return

from sglang.srt.plugins import load_plugins

load_plugins()

model_type, dispatch_argv = _extract_model_type_override(extra_argv)
model_path = get_model_path(dispatch_argv)
try:
4 changes: 4 additions & 0 deletions python/sglang/launch_server.py
@@ -56,6 +56,10 @@ def run_server(server_args):
stacklevel=1,
)

from sglang.srt.plugins import load_plugins

load_plugins()

server_args = prepare_server_args(sys.argv[1:])

try:
8 changes: 7 additions & 1 deletion python/sglang/srt/compilation/backend.py
@@ -22,6 +22,7 @@
from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend
from sglang.srt.compilation.pass_manager import PostGradPassManager
from sglang.srt.environ import envs
from sglang.srt.platforms import current_platform
from sglang.srt.utils.common import is_npu

logger = logging.getLogger(__name__)
@@ -48,7 +49,12 @@ def make_backend(
sglang_backend,
):

backend_cls = CUDAPiecewiseBackend if not is_npu() else NPUPiecewiseBackend
if current_platform.is_out_of_tree():
backend_cls = current_platform.get_piecewise_backend_cls()
elif is_npu():
backend_cls = NPUPiecewiseBackend
else:
backend_cls = CUDAPiecewiseBackend
return backend_cls(
graph,
compile_config,
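For context, make_backend() only needs two answers from the platform object. Below is a minimal sketch of the out-of-tree side, assuming a Platform base class exported from sglang.srt.platforms; the base-class name, the plugin module, and everything in the bodies are assumptions — only is_out_of_tree() and get_piecewise_backend_cls() come from the diff above.

    from sglang.srt.platforms import Platform  # assumed base-class location


    class MyAcceleratorPlatform(Platform):  # hypothetical plugin platform
        device_name = "my_accel"

        def is_out_of_tree(self) -> bool:
            # In-tree platforms (cuda/npu/cpu/...) report False here.
            return True

        def get_piecewise_backend_cls(self):
            # The returned class must accept the same constructor arguments
            # that make_backend() passes to CUDAPiecewiseBackend.
            from my_plugin.compilation import MyPiecewiseBackend  # hypothetical

            return MyPiecewiseBackend
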
10 changes: 10 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -84,6 +84,7 @@
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info
from sglang.srt.plugins import load_plugins
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
MultiprocessingSerializer,
@@ -167,6 +168,10 @@ def __init__(self, **kwargs):
Please refer to `ServerArgs` for the documentation.
"""

# Ensure plugins are loaded before ServerArgs construction,
# so hooks on ServerArgs.__post_init__ fire correctly.
load_plugins()

# Parse server_args
if "server_args" in kwargs:
# Directly load server_args
@@ -647,6 +652,11 @@ def _launch_subprocesses(
# Configure global environment
configure_logger(server_args)
_set_envs_and_config(server_args)

# Defensive: ensure plugins loaded (may already be loaded by
# Engine.__init__ or CLI entry).
load_plugins()

server_args.check_server_args()
_set_gc(server_args)

4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
@@ -554,6 +554,10 @@ class Envs:
# Sglang Cache Dir
SGLANG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/sglang"))

# Plugin system
SGLANG_PLATFORM = EnvStr("")
SGLANG_PLUGINS = EnvStr("")


envs = Envs()
EnvField._allow_set_name = False
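These two variables gate the plugin system: SGLANG_PLATFORM presumably names an out-of-tree platform to activate, and SGLANG_PLUGINS selects which plugins to load. The actual loader lives in python/sglang/srt/plugins and is not rendered in this diff; the following is only a rough sketch of the shape an idempotent, entry-point-based load_plugins() usually takes. The entry-point group name, the allowlist semantics, and the whole body are assumptions, not the PR's code — the idempotency guard matches the "may already be loaded" comment in engine.py above.

    import importlib.metadata
    import logging
    import os

    logger = logging.getLogger(__name__)
    _plugins_loaded = False  # module-level guard: repeated calls are no-ops


    def load_plugins() -> None:
        """Illustrative sketch only; the real loader is in sglang/srt/plugins."""
        global _plugins_loaded
        if _plugins_loaded:  # called from Engine.__init__, CLI, and scheduler
            return
        _plugins_loaded = True

        # SGLANG_PLUGINS as an optional comma-separated allowlist (assumption).
        allowlist = {n for n in os.environ.get("SGLANG_PLUGINS", "").split(",") if n}
        for ep in importlib.metadata.entry_points(group="sglang.plugins"):
            if allowlist and ep.name not in allowlist:
                continue
            try:
                ep.load()()  # entry point resolves to a register() callable (assumption)
            except Exception:
                logger.exception("Failed to load plugin %s", ep.name)
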
23 changes: 22 additions & 1 deletion python/sglang/srt/layers/utils/multi_platform.py
@@ -1,8 +1,9 @@
from typing import Callable
from typing import Callable, ClassVar

from torch import nn

from sglang.kernel_api_logging import debug_kernel_api
from sglang.srt.platforms import current_platform
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
@@ -23,6 +24,15 @@


class MultiPlatformOp(nn.Module):

# OOT forward registry: maps dispatch_key -> {op_cls -> forward_fn}
_oot_forward_registry: ClassVar[dict[str, dict[type, Callable]]] = {}

@classmethod
def register_oot_forward(cls, op_cls: type, fn: Callable, platform_key: str):
"""Register an OOT forward implementation for a specific op class and platform."""
cls._oot_forward_registry.setdefault(platform_key, {})[op_cls] = fn

def __init__(self):
super().__init__()
self._forward_method: Callable = self.dispatch_forward()
@@ -100,6 +110,17 @@ def forward_cpu(self, *args, **kwargs):
return self.forward_native(*args, **kwargs)

def dispatch_forward(self):
# OOT platform dispatch: check registry then method lookup
if current_platform.is_out_of_tree():
key = current_platform.get_dispatch_key_name()
oot = self._oot_forward_registry.get(key, {})
if type(self) in oot:
return oot[type(self)].__get__(self)
method = getattr(self, f"forward_{key}", None)
[Review thread on this dispatch change]

Collaborator:
If we use getattr with the dispatch key to auto-find forward methods defined in subclasses, it seems we no longer need all the forward_xxx methods in MultiPlatformOp? Maybe forward_native can be kept as an escape hatch.

The current MultiPlatformOp forward_xxx methods also provide some fallback logic (e.g., hip -> cuda), and we could have fallback_dispatch_keys for each platform to cover this.

Contributor Author:
You're right. Once the platform interface stabilizes, we'd love to clean these up; the fallback_dispatch_keys idea is a nice approach too. For now we're trying to keep in-tree changes minimal in this PR, so I'd prefer to address it in a follow-up if that's okay with you. Really appreciate the input!

if method is not None:
return method
return self.forward_native

if _is_cuda:
return self.forward_cuda
elif _is_hip:
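To make the registry above concrete, here is a usage sketch from the plugin side. Only register_oot_forward() and the dispatch order (registry first, then forward_{dispatch_key}, then forward_native) come from the diff; the op import path, the kernel, and the platform key are hypothetical.

    from sglang.srt.layers.utils.multi_platform import MultiPlatformOp
    from sglang.srt.layers.activation import SiluAndMul  # hypothetical import path


    def silu_and_mul_forward_my_accel(self, x):
        # Plugin-provided kernel; the signature must mirror forward_native.
        ...


    # dispatch_forward() consults this registry first for OOT platforms,
    # before falling back to a forward_{dispatch_key} method or forward_native.
    MultiPlatformOp.register_oot_forward(
        SiluAndMul, silu_and_mul_forward_my_accel, platform_key="my_accel"
    )
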
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -204,6 +204,7 @@
)
from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.plugins import load_plugins
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -3742,6 +3743,8 @@ def run_scheduler_process(
dp_rank: Optional[int],
pipe_writer,
):
# Load plugins so hooks can override Scheduler and its dependencies.
load_plugins()
dp_rank = configure_scheduler_process(
server_args,
gpu_id,
15 changes: 12 additions & 3 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -54,6 +54,7 @@
set_mla_kv_buffer_triton_fp8_quant,
set_mla_kv_scale_buffer_triton,
)
from sglang.srt.platforms import current_platform
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
@@ -780,8 +781,12 @@ def __init__(
self._create_buffers()

self.device_module = torch.get_device_module(self.device)

_use_alt_stream = _is_cuda or current_platform.is_cuda_alike()
self.alt_stream = (
self.device_module.Stream() if _is_cuda and enable_alt_stream else None
self.device_module.Stream()
if _use_alt_stream and enable_alt_stream
else None
)

if enable_kv_cache_copy:
@@ -1262,7 +1267,9 @@ def __init__(

TokenToKVPoolClass = MHATokenToKVPool

if _is_npu:
if current_platform.is_out_of_tree():
TokenToKVPoolClass = current_platform.get_mha_kv_pool_cls()
elif _is_npu:
from sglang.srt.hardware_backend.npu.memory_pool_npu import (
NPUMHATokenToKVPool,
)
@@ -1283,7 +1290,9 @@ def __init__(

TokenToKVPoolClass = MLATokenToKVPool

if _is_npu:
if current_platform.is_out_of_tree():
TokenToKVPoolClass = current_platform.get_mla_kv_pool_cls()
elif _is_npu:
from sglang.srt.hardware_backend.npu.memory_pool_npu import (
NPUMLATokenToKVPool,
)
44 changes: 34 additions & 10 deletions python/sglang/srt/model_executor/model_runner.py
@@ -150,6 +150,7 @@
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.platforms import current_platform
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import (
ServerArgs,
@@ -207,6 +208,8 @@
from sglang.srt.hardware_backend.npu.utils import init_npu_backend

init_npu_backend()
elif current_platform.is_out_of_tree():
current_platform.init_backend()

MLA_ATTENTION_BACKENDS = [
"aiter",
@@ -702,6 +705,7 @@ def initialize(self, pre_model_load_memory: float):
# Init routed experts capturer
self.init_routed_experts_capturer()

# TODO: Refactor device-specific init branches into platform interface (separate PR).
# Must be called BEFORE init_device_graphs() so CUDA graph capture
# runs with aux hidden state capture enabled.
self.init_aux_hidden_state_capture()
@@ -714,6 +718,13 @@
elif self.device in ["npu", "cpu"]:
self.init_attention_backend()
self.init_device_graphs()
elif current_platform.is_out_of_tree():
self.init_attention_backend()
if current_platform.support_cuda_graph():
[Review thread on support_cuda_graph]

Collaborator (@alexnails, Mar 31, 2026):
Kind of on the same note as another comment, but by "cuda graph" we just mean FULL-style CUDA graph capture, right? Might be worth renaming... hmm.

Contributor Author:
Totally understand the naming confusion. However, renaming touches 85+ sites across the codebase, which is a pretty large scope for this PR. Would it be okay to leave this as a TODO and address it in a dedicated follow-up?

Collaborator:
Yeah, I am thinking that that is its own PR. Please leave a TODO.

Contributor Author:
Added a TODO comment at the OOT branch in initialize().
self.init_device_graphs()
else:
self.graph_runner = None
self.graph_mem_usage = 0
else:
self.graph_runner = None
self.graph_mem_usage = 0
@@ -1483,7 +1494,14 @@ def model_load_weights(model, iter):
self.server_args.load_format = load_format
self.load_config = load_config

if recapture_cuda_graph and (self.device == "cuda" or self.device == "musa"):
if recapture_cuda_graph and (
self.device == "cuda"
or self.device == "musa"
or (
current_platform.is_out_of_tree()
and current_platform.support_cuda_graph()
)
):
self.init_device_graphs()

logger.info("Update weights end.")
@@ -2532,23 +2550,29 @@ def init_device_graphs(self):
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
graph_backend = defaultdict(
lambda: "cuda graph",
lambda: f"{current_platform.device_name} graph",
{
"cuda": "cuda graph",
"musa": "cuda graph",
"cpu": "cpu graph",
"npu": "npu graph",
},
)
logger.info(
f"Capture {graph_backend[self.device]} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
graph_runners = defaultdict(
lambda: CudaGraphRunner,
{
"cpu": CPUGraphRunner,
"npu": NPUGraphRunner,
},
)
self.graph_runner = graph_runners[self.device](self)
if current_platform.is_out_of_tree():
[Review comment on init_device_graphs]

Contributor:
Same as Alex said, we need to clean this up in the next PR.
GraphRunnerCls = current_platform.get_graph_runner_cls()
self.graph_runner = GraphRunnerCls(self)
else:
graph_runners = defaultdict(
lambda: CudaGraphRunner,
{
"cpu": CPUGraphRunner,
"npu": NPUGraphRunner,
},
)
self.graph_runner = graph_runners[self.device](self)

after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.graph_mem_usage = before_mem - after_mem
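Continuing the hypothetical platform sketch from the compilation section, the model-runner call sites above imply three more hooks. The method names come from the diff; the bodies and plugin module are illustrative assumptions.

    class MyAcceleratorPlatform(Platform):  # hypothetical, continued
        def init_backend(self) -> None:
            # One-time runtime/device setup, analogous to init_npu_backend().
            ...

        def support_cuda_graph(self) -> bool:
            # Per the review thread above, "cuda graph" currently means
            # FULL-style graph capture; a rename is deferred to a follow-up PR.
            return True

        def get_graph_runner_cls(self):
            # Must be drop-in compatible with CudaGraphRunner's interface.
            from my_plugin.graph_runner import MyGraphRunner  # hypothetical

            return MyGraphRunner
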
70 changes: 68 additions & 2 deletions python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
@@ -282,7 +282,63 @@ def _init_pools(self: ModelRunner):

# Initialize token_to_kv_pool
is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
if self.server_args.attention_backend == "ascend" and not self.mambaish_config:

# Check out-of-tree platform (plugin system) first
from sglang.srt.platforms import current_platform

if current_platform.is_out_of_tree() and not self.mambaish_config:
if self.use_mla_backend and is_nsa_model:
PoolCls = current_platform.get_nsa_kv_pool_cls()
self.token_to_kv_pool = PoolCls(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
kv_cache_dim=self.calculate_mla_kv_cache_dim(),
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
)
elif self.use_mla_backend:
PoolCls = current_platform.get_mla_kv_pool_cls()
self.token_to_kv_pool = PoolCls(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
index_head_dim=(
self.model_config.index_head_dim if is_nsa_model else None
),
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
else:
PoolCls = current_platform.get_mha_kv_pool_cls()
self.token_to_kv_pool = PoolCls(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
[Review comment on lines +289 to +338]

Contributor (severity: medium):
The logic for initializing the different KV pool types for out-of-tree platforms involves significant code duplication, especially in the constructor arguments. It can be refactored for readability and maintainability by extracting the common arguments into a dictionary:

        if current_platform.is_out_of_tree() and not self.mambaish_config:
            pool_args = {
                "max_total_num_tokens": self.max_total_num_tokens,
                "page_size": self.page_size,
                "dtype": self.kv_cache_dtype,
                "layer_num": self.num_effective_layers,
                "device": self.device,
                "enable_memory_saver": self.server_args.enable_memory_saver,
                "start_layer": self.start_layer,
                "end_layer": self.end_layer,
            }

            if self.use_mla_backend and is_nsa_model:
                PoolCls = current_platform.get_nsa_kv_pool_cls()
                pool_args.update({
                    "kv_lora_rank": self.model_config.kv_lora_rank,
                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
                    "kv_cache_dim": self.calculate_mla_kv_cache_dim(),
                    "index_head_dim": get_nsa_index_head_dim(
                        self.model_config.hf_config
                    ),
                })
            elif self.use_mla_backend:
                PoolCls = current_platform.get_mla_kv_pool_cls()
                pool_args.update({
                    "kv_lora_rank": self.model_config.kv_lora_rank,
                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
                    "index_head_dim": (
                        self.model_config.index_head_dim if is_nsa_model else None
                    ),
                })
            else:
                PoolCls = current_platform.get_mha_kv_pool_cls()
                pool_args.update({
                    "head_num": self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    "head_dim": self.model_config.head_dim,
                })

            self.token_to_kv_pool = PoolCls(**pool_args)

elif (
self.server_args.attention_backend == "ascend" and not self.mambaish_config
):
if self.is_hybrid_swa:
from sglang.srt.hardware_backend.npu.memory_pool_npu import (
NPUMHATokenToKVPool,
@@ -513,7 +569,17 @@ def _init_pools(self: ModelRunner):
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and (
if current_platform.is_out_of_tree():
AllocatorCls = current_platform.get_paged_allocator_cls()
self.token_to_kv_pool_allocator = AllocatorCls(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
elif _is_npu and (
self.server_args.attention_backend == "ascend"
or self.hybrid_gdn_config is not None
):
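The KV-cache hooks follow the same pattern: each getter returns a class whose constructor accepts the keyword arguments used at the call sites above (the MHA pool takes head_num/head_dim, the MLA and NSA pools take kv_lora_rank/qk_rope_head_dim, and the allocator takes page_size/dtype/device/kvcache/need_sort). A sketch, with all plugin-side names hypothetical:

    class MyAcceleratorPlatform(Platform):  # hypothetical, continued
        def get_mha_kv_pool_cls(self):
            from my_plugin.memory_pool import MyMHATokenToKVPool

            return MyMHATokenToKVPool

        def get_mla_kv_pool_cls(self):
            from my_plugin.memory_pool import MyMLATokenToKVPool

            return MyMLATokenToKVPool

        def get_nsa_kv_pool_cls(self):
            from my_plugin.memory_pool import MyNSATokenToKVPool

            return MyNSATokenToKVPool

        def get_paged_allocator_cls(self):
            # Constructed with the keywords shown at the call site above.
            from my_plugin.allocator import MyPagedAllocator

            return MyPagedAllocator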