Multi platform Plugin #21388

Merged
merrymercy merged 22 commits into sgl-project:main from Baidu-AIAK:plugin on Apr 20, 2026

Conversation

@Baidu-AIAK (Contributor) commented Mar 25, 2026

Summary

Introduce a unified plugin framework for SGLang, inspired by vLLM's platform abstraction, enabling hardware vendors and advanced users to extend SGLang without forking or modifying the main repository.

The framework provides two plugin types, both discovered via Python's standard setuptools entry_points mechanism:

  • Platform Plugins (sglang.platform_plugins): Register custom hardware platforms — device ops, KV cache pools, attention backends, CUDA Graph runners, compilation backends, multi-platform dispatch, etc.
  • General Function Plugins (sglang.plugins): Inject hooks (BEFORE / AFTER / AROUND / REPLACE) into arbitrary functions/methods in SGLang, or replace entire classes — all managed by a single HookRegistry.
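
As a concrete example, a vendor package might register both plugin types in its packaging metadata. A minimal setuptools sketch — the entry-point group names are the ones this PR defines, while the package, module, and callable names are hypothetical:

# setup.py — illustrative only
from setuptools import setup

setup(
    name="my_sglang_plugin",          # hypothetical vendor package
    version="0.1.0",
    packages=["my_sglang_plugin"],
    entry_points={
        # platform plugin: resolves to an activate() callable (see discovery flow below)
        "sglang.platform_plugins": [
            "my_device = my_sglang_plugin:activate",
        ],
        # general plugin: resolves to a callable that registers hooks when executed
        "sglang.plugins": [
            "my_hooks = my_sglang_plugin.hooks:register",
        ],
    },
)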

Key Design Principles

  • Non-invasive: Existing CUDA/ROCm/NPU/XPU code is untouched. In-tree hardware continues to use the existing is_cuda() / is_npu() utility functions (432+ call sites across 195 files). The Platform system is exclusively for OOT discovery.
  • Zero-config: Plugins are auto-discovered after pip install — no SGLang code changes required.
  • Allowlist control: SGLANG_PLUGINS env var provides comma-separated allowlist filtering.

Changes

New framework core files (5 files)

| File | Description |
|---|---|
| srt/plugins/__init__.py | Entry point discovery, idempotent load_plugins(), SGLANG_PLUGINS env var filtering |
| srt/plugins/hook_registry.py | HookRegistry with BEFORE/AFTER/AROUND/REPLACE support, plugin_hook decorator, class replacement via setattr, resolve_obj() |
| srt/platforms/device_mixin.py | PlatformEnum (10 members) + DeviceMixin base class (identity queries + device operations) |
| srt/platforms/interface.py | SRTPlatform(DeviceMixin) — factory methods, capability flags, lifecycle hooks |
| srt/platforms/__init__.py | Lazy-initialized current_platform singleton, pure entry_points OOT discovery + SGLANG_PLATFORM env var override |

Plugin loading integration (4 call sites)

| File | Location | Process | Purpose |
|---|---|---|---|
| cli/serve.py | serve() top | Main | sglang serve CLI — before model type dispatch |
| launch_server.py | __main__ | Main | Legacy python -m sglang.launch_server entrypoint |
| entrypoints/engine.py | _launch_subprocesses() | Main | Python API Engine(model_path=...) — before check_server_args() |
| managers/scheduler.py | run_scheduler_process() | Subprocess | Spawned subprocess — re-registers hooks before Scheduler instantiation |

load_plugins() is idempotent (boolean guard). The subprocess call is necessary because mp.Process(spawn) creates a fresh Python interpreter that does not inherit main process memory state.
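
For illustration, the guard and the single-call API can be sketched as follows; the real srt/plugins/__init__.py differs in detail, this only shows the flow:

# a minimal sketch of load_plugins(); not the merged implementation
import logging
import os
from importlib.metadata import entry_points

logger = logging.getLogger(__name__)
_plugins_loaded = False  # boolean guard: safe in both the main process and spawned subprocesses

def load_plugins() -> None:
    global _plugins_loaded
    if _plugins_loaded:
        return
    _plugins_loaded = True

    # SGLANG_PLUGINS: comma-separated allowlist; unset means "load everything"
    allowlist = os.environ.get("SGLANG_PLUGINS")
    allowed = set(allowlist.split(",")) if allowlist else None

    for ep in entry_points(group="sglang.plugins"):
        if allowed is not None and ep.name not in allowed:
            continue
        try:
            ep.load()()  # executing the plugin registers its hooks
        except Exception:
            logger.warning("Failed to load plugin %s", ep.name, exc_info=True)

    # single entry point: hooks are applied internally, no separate call needed
    from sglang.srt.plugins.hook_registry import HookRegistry
    HookRegistry.apply_hooks()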

OOT code paths (7 files, all additive — no existing code removed)

| File | OOT branches | Purpose |
|---|---|---|
| server_args.py | 3 | apply_server_args_defaults(), disable piecewise CUDA Graph, default attention backend |
| model_runner.py | 5 | Module-level init_backend(), post-init initialization, graph recapture after weight update, graph log label, custom GraphRunner class |
| model_runner_kv_cache_mixin.py | 2 | KV cache pool selection (MHA/MLA/NSA) via platform factory methods, paged allocator |
| memory_pool.py | 3 | HybridMambaTokenToKVPool MHA/MLA pools, is_cuda_alike() for alt_stream |
| multi_platform.py | 1 | OOT forward dispatch registry + forward_{key} method lookup |
| compilation/backend.py | 1 | Piecewise compilation backend via platform factory |
| utils/common.py | 3 | get_available_gpu_memory() OOT fallback, get_device_memory_capacity(), get_compiler_backend() |

Architecture

Main Process (cli/serve.py | engine.py)
  ├── load_plugins()
  │     ├── discover & execute general plugins (sglang.plugins)
  │     └── HookRegistry.apply_hooks()      ← monkey-patch main process targets
  ├── check_server_args()                   ← affected by hooks ✅
  └── mp.Process(spawn) ─────────────────── fresh subprocess, all patches lost
        │
        Subprocess (scheduler.py → run_scheduler_process)
          ├── load_plugins()                ← idempotent re-load, re-registers hooks
          │     └── HookRegistry.apply_hooks()  ← patch subprocess targets (Scheduler, ModelRunner, etc.)
          └── Scheduler(...)                ← gets replacement class if registered ✅

Class hierarchy

DeviceMixin                                  # identity queries + device operations
├── SRTPlatform(DeviceMixin)                 # + factory methods, capability flags
│   └── MySRTPlatform(SRTPlatform, MyDeviceMixin)   # OOT vendor platform
└── MMPlatform(DeviceMixin)                  # (future) multimodal/diffusion
    └── MyMMPlatform(MMPlatform, MyDeviceMixin)

Vendors implement MyDeviceMixin(DeviceMixin) once for device operations, then mix it into both SRTPlatform and future MMPlatform subclasses via Python MRO.
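
Concretely, an OOT package might compose like this; a sketch in which the device-op method names are assumptions (they follow the review discussion further down), not the merged interface:

import torch

from sglang.srt.platforms.device_mixin import DeviceMixin, PlatformEnum
from sglang.srt.platforms.interface import SRTPlatform

class MyDeviceMixin(DeviceMixin):
    """Device operations implemented once for the vendor hardware."""

    _enum = PlatformEnum.OOT

    def get_device(self, local_rank: int = 0) -> torch.device:
        return torch.device("cuda", local_rank)  # substitute the vendor device type

    def get_distributed_backend(self) -> str:
        return "nccl"  # substitute the vendor collective backend

class MySRTPlatform(SRTPlatform, MyDeviceMixin):
    """SRT platform: the MRO resolves device ops from MyDeviceMixin;
    SRT factory methods (KV pools, graph runner, ...) are overridden here."""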

Hook system

The plugin_hook decorator and HookRegistry.register() provide a unified API for both function hooks and class replacement:

from sglang.srt.plugins.hook_registry import plugin_hook, HookType

# Function hook — wrap Scheduler.__init__ with extra logic
@plugin_hook("sglang.srt.managers.scheduler.Scheduler.__init__", HookType.AROUND)
def my_scheduler_init_hook(original_fn, self, *args, **kwargs):
    original_fn(self, *args, **kwargs)
    logger.info("Scheduler initialized with custom config")

# Class replacement — swap Scheduler entirely
@plugin_hook("sglang.srt.managers.scheduler.Scheduler", HookType.REPLACE)
class MyScheduler(Scheduler):
    ...

Class replacement uses direct setattr (not functools.wraps wrapper), preserving isinstance / issubclass / inheritance semantics. A dual validation mechanism prevents misuse: classes can only use REPLACE, and function hooks cannot target class objects.
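
To make the mechanics concrete, here is a simplified sketch of the apply step — resolve_obj, the setattr replacement, and the hook types are from this PR, but the bodies are illustrative rather than the merged code:

import functools
import importlib

def resolve_obj(target: str):
    """Resolve "pkg.mod.Class.attr" to (owner, attr_name): import the longest
    importable module prefix, then walk attributes to the final name's owner."""
    parts = target.split(".")
    for i in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
            break
        except ModuleNotFoundError:
            continue
    else:
        raise ImportError(f"cannot resolve {target!r}")
    for name in parts[i:-1]:
        obj = getattr(obj, name)
    return obj, parts[-1]

def apply_replace(target: str, replacement: type) -> None:
    # direct setattr: isinstance/issubclass now see the replacement class itself
    owner, name = resolve_obj(target)
    setattr(owner, name, replacement)

def apply_around(target: str, hook) -> None:
    owner, name = resolve_obj(target)
    original = getattr(owner, name)

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        return hook(original, *args, **kwargs)

    setattr(owner, name, wrapper)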

Platform discovery flow

1. SGLANG_PLATFORM env var     → forced class (dev/testing)
2. entry_points("sglang.platform_plugins") → OOT discovery (at most one active)
3. Fallback                    → base SRTPlatform(UNSPECIFIED)
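
In code, this priority order looks roughly as follows. _resolve_platform and _load_platform_class match the test naming later in this thread, but their bodies here are a sketch:

import os
from importlib.metadata import entry_points

_current_platform = None

def _load_platform_class(qualname: str):
    module_name, _, class_name = qualname.rpartition(".")
    module = __import__(module_name, fromlist=[class_name])
    return getattr(module, class_name)

def _resolve_platform():
    # 1. explicit override for dev/testing
    if spec := os.environ.get("SGLANG_PLATFORM"):
        return _load_platform_class(spec)()

    # 2. OOT discovery: at most one active platform plugin
    for ep in entry_points(group="sglang.platform_plugins"):
        if qualname := ep.load()():  # activate() returns an FQN, or None to skip
            return _load_platform_class(qualname)()

    # 3. fallback: base platform with PlatformEnum.UNSPECIFIED
    from sglang.srt.platforms.interface import SRTPlatform
    return SRTPlatform()

def current_platform():
    # lazy singleton; the real module exposes this as an attribute, not a function
    global _current_platform
    if _current_platform is None:
        _current_platform = _resolve_platform()
    return _current_platform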

Documentation

  • docs/platforms/plugin.md — comprehensive guide covering both plugin types, architecture diagrams, API reference, and quickstart examples.

@gemini-code-assist (Contributor)

Summary of Changes

This pull request establishes a robust, multi-platform plugin architecture for SGLang. The primary goal is to enable external extensibility, allowing developers to integrate new hardware platforms or customize core behaviors through a non-invasive plugin system. This significantly enhances SGLang's adaptability by abstracting hardware-specific logic and providing flexible function hooking and class replacement mechanisms, all while maintaining a clean core codebase.

Highlights

  • Unified Plugin Framework: Introduced a unified plugin framework for SGLang, enabling hardware vendors and advanced users to extend functionality without modifying the core repository.
  • Two Plugin Types: Implemented two distinct plugin types: Platform Plugins for custom hardware integration (device operations, KV cache pools, attention backends) and General Function Plugins for injecting hooks or replacing classes.
  • Automatic Discovery and Non-Invasive Design: Enabled automatic plugin discovery via Python's setuptools entry_points, ensuring a zero-configuration setup. The design is non-invasive, integrating out-of-tree (OOT) code paths as elif branches to preserve existing hardware-specific logic.
  • Two-Phase Hook Application: Incorporated a two-phase hook application process to correctly handle SGLang's spawn multiprocessing model, ensuring hooks are applied consistently in both main and worker processes.
  • Allowlist Control: Provided granular control over plugin loading through the SGLANG_PLUGINS environment variable, allowing users to specify an allowlist of plugins.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a comprehensive plugin-based architecture to SGLang, enabling "out-of-tree" (OOT) hardware platform support and general engine extensibility. It defines a Platform abstraction with a PlatformEnum and a CudaPlatform implementation, handling dynamic platform discovery and lazy initialization. This new abstraction is integrated across various core components, including compilation, engine entrypoints, multi-platform operations, memory management, model runner, and server argument processing, allowing OOT platforms to customize behavior and provide specific implementations for various subsystems. A review comment suggests refactoring the KV pool initialization logic for OOT platforms to reduce code duplication and improve maintainability.

Comment on lines +449 to +500
        if current_platform.is_out_of_tree() and not self.mambaish_config:
            if self.use_mla_backend and is_nsa_model:
                PoolCls = current_platform.get_nsa_kv_pool_cls()
                self.token_to_kv_pool = PoolCls(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    kv_cache_dim=self.calculate_mla_kv_cache_dim(),
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                    index_head_dim=get_nsa_index_head_dim(
                        self.model_config.hf_config
                    ),
                )
            elif self.use_mla_backend:
                PoolCls = current_platform.get_mla_kv_pool_cls()
                self.token_to_kv_pool = PoolCls(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    kv_lora_rank=self.model_config.kv_lora_rank,
                    qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                    index_head_dim=(
                        self.model_config.index_head_dim if is_nsa_model else None
                    ),
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
            else:
                PoolCls = current_platform.get_mha_kv_pool_cls()
                self.token_to_kv_pool = PoolCls(
                    self.max_total_num_tokens,
                    page_size=self.page_size,
                    dtype=self.kv_cache_dtype,
                    head_num=self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    head_dim=self.model_config.head_dim,
                    layer_num=self.num_effective_layers,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    start_layer=self.start_layer,
                    end_layer=self.end_layer,
                )
@gemini-code-assist (Contributor), severity: medium

The logic for initializing the different KV pool types for out-of-tree platforms involves significant code duplication, especially for the constructor arguments. This can be refactored to improve readability and maintainability by extracting common arguments into a dictionary.

        if current_platform.is_out_of_tree() and not self.mambaish_config:
            pool_args = {
                "max_total_num_tokens": self.max_total_num_tokens,
                "page_size": self.page_size,
                "dtype": self.kv_cache_dtype,
                "layer_num": self.num_effective_layers,
                "device": self.device,
                "enable_memory_saver": self.server_args.enable_memory_saver,
                "start_layer": self.start_layer,
                "end_layer": self.end_layer,
            }

            if self.use_mla_backend and is_nsa_model:
                PoolCls = current_platform.get_nsa_kv_pool_cls()
                pool_args.update({
                    "kv_lora_rank": self.model_config.kv_lora_rank,
                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
                    "kv_cache_dim": self.calculate_mla_kv_cache_dim(),
                    "index_head_dim": get_nsa_index_head_dim(
                        self.model_config.hf_config
                    ),
                })
            elif self.use_mla_backend:
                PoolCls = current_platform.get_mla_kv_pool_cls()
                pool_args.update({
                    "kv_lora_rank": self.model_config.kv_lora_rank,
                    "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
                    "index_head_dim": (
                        self.model_config.index_head_dim if is_nsa_model else None
                    ),
                })
            else:
                PoolCls = current_platform.get_mha_kv_pool_cls()
                pool_args.update({
                    "head_num": self.model_config.get_num_kv_heads(
                        get_attention_tp_size()
                    ),
                    "head_dim": self.model_config.head_dim,
                })

            self.token_to_kv_pool = PoolCls(**pool_args)

@github-actions (Bot) added the documentation label Mar 26, 2026
Comment thread python/sglang/srt/managers/tp_worker.py Outdated
Comment on lines +263 to +273
# Apply worker-level platform patches (phase 2 monkey patching).
from sglang.srt.platforms import current_platform

current_platform.apply_worker_patches()

# Apply deferred hooks (phase 2, idempotent).
# Re-discover plugins in subprocess (spawn'd processes lose main-process state).
from sglang.srt.plugins import load_general_plugins
from sglang.srt.plugins.hook_registry import HookRegistry
load_general_plugins()
HookRegistry.apply_hooks()
Contributor:

We should call apply_hooks right after the process is created, so the hook can override anything including Scheduler and TpWorker.
A good place is run_scheduler_process

@Baidu-AIAK (Contributor Author):

Thanks for the suggestion! Done — load_plugins() (which calls HookRegistry.apply_hooks() internally) is now called in run_scheduler_process() before Scheduler() construction. The call in tp_worker.py has been removed since TpModelWorker is always created inside the Scheduler process, so it was redundant.

Comment thread python/sglang/srt/managers/tp_worker.py Outdated
Comment on lines +272 to +273
load_general_plugins()
HookRegistry.apply_hooks()
Contributor:

To simplify the API, we should reduce this to a single function call.

@Baidu-AIAK (Contributor Author):

Done — load_plugins() is now the single entry point. It discovers plugins, executes them, and calls HookRegistry.apply_hooks() internally. Callers no longer need a separate apply_hooks() step.

Comment thread docs/platforms/plugin.md Outdated
| Plugin Type | Entry Point Group | Purpose |
|---|---|---|
| **Hardware Platform Plugin** | `sglang.platform_plugins` | Register a custom hardware platform (device operations, KV cache pools, attention backends, CUDA Graph, compilation backends, etc.) |
| **General Function Plugin** | `sglang.general_plugins` | Inject hooks (before/after/around/replace) into any function/method in sglang, or replace entire classes |
@merrymercy (Contributor) commented Mar 27, 2026:

The name sglang.general_plugins is confusing. We do not have a folder with this name sglang.general_plugins. Can you improve the name?

@Baidu-AIAK (Contributor Author):

Makes sense, thanks! Renamed to sglang.plugins — shorter and consistent with the actual package structure.

Comment thread python/sglang/srt/platforms/__init__.py
@alexnails (Collaborator) left a comment:

I probably have some more comments, but I'm going to drop this now so there is stuff to go over. Please Slack me if you have any questions, as I want to help you as much as I can.

def is_musa(self) -> bool:
    return self._enum == PlatformEnum.MUSA

def is_cuda_alike(self) -> bool:
Collaborator:

@yeahdongcn multimodal and SRT handle is_cuda_alike differently w.r.t. MUSA. Leave a note that this needs to be resolved later?

@Baidu-AIAK (Contributor Author):

In DeviceMixin, we unify is_cuda_alike() as CUDA + ROCm + MUSA. Therefore, once MM inherits DeviceMixin, it will automatically obtain a consistent definition. The SRT prefix in SRTPlatform (as in the example you provided) is intended to distinguish it from the future MMPlatform(DeviceMixin) — both share the same DeviceMixin foundation but carry subsystem-specific factory methods. If you have better naming suggestions, I would be happy to make adjustments.

Collaborator:

Sorry, I missed that comment. Yes, the behavior differs between multimodal_gen and SRT. In multimodal_gen, current_platform.is_cuda_alike() is only used to determine graph capture behavior and the communication method.

In contrast, in SRT, it is used to decide whether certain kernels can be imported.

For example:

if _is_cuda_alike:
    from sgl_kernel import (
        cutlass_w4a8_moe_mm,
        get_cutlass_w4a8_moe_mm_data,
    )


class PlatformEnum(enum.Enum):
    """Enumeration of known platform types."""

Collaborator:

this is missing hardware types (e.g. NPU, off the top of my head), and if following my other PR comments, this should be a mixin.

@Baidu-AIAK (Contributor Author):

Thanks, addressed! PlatformEnum now covers all current hardware: CUDA, ROCM, CPU, XPU, MUSA, NPU, TPU, MPS, OOT, UNSPECIFIED. All identity queries (is_cuda(), is_npu(), is_musa(), etc.) are defined in DeviceMixin and derived automatically from _enum.
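
A compact sketch of that derivation — the member values are illustrative, and the real DeviceMixin carries many more queries plus the device operations:

import enum

class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    CPU = enum.auto()
    XPU = enum.auto()
    MUSA = enum.auto()
    NPU = enum.auto()
    TPU = enum.auto()
    MPS = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()

class DeviceMixin:
    _enum: PlatformEnum = PlatformEnum.UNSPECIFIED

    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_npu(self) -> bool:
        return self._enum == PlatformEnum.NPU

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        # unified definition discussed in this thread: CUDA + ROCm + MUSA
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM, PlatformEnum.MUSA)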

SGLang Hardware Platform Abstraction.

Defines the Platform base class and PlatformEnum. Each hardware backend
(CUDA, ROCm, NPU, XPU, etc.) implements a Platform subclass providing
@alexnails (Collaborator) commented Mar 27, 2026:

this approach has a few things we need to address:

  1. We should have the multimodal and srt platforms share the same core functionalities as multimodal already has a platform object (which will at some point inherit from our base abstraction). We want to avoid diamond inheritance though, so we should move to a device mixin approach IMO.
  2. Future ModelRunner / SpecDec refactor: With a mixin, ModelRunner can compose with just DeviceMixin for device operations without needing the full SRTPlatform. This is cleaner than an ABC hierarchy for the planned SpecDec refactor.
  3. As far as I understand, the @classmethod decorators / implementation don't focus on failing fast. If a hardware platform does not support something, it should raise NotImplementedError

Example

# Mixin -- device operations only, no ABC chain
class DeviceMixin:
    name: str
    _enum: PlatformEnum
    def get_device(self, local_rank) -> torch.device: ...
    def get_device_name(self, device_id=0) -> str: ...
    def get_distributed_backend(self) -> str: ...
    def get_available_memory(self, device_id=0) -> tuple[int, int]: ...
    # ... ~15 shared device/memory/distributed methods

class SRTPlatform(DeviceMixin):
    # SRT-specific: graph runners, KV pools, quant, compilation
    def get_graph_runner_class(self) -> type: ...
    def get_kv_pool_class(self, use_mla: bool) -> type: ...

class MMPlatform(DeviceMixin):
    # MM-specific methods
    def get_attn_backend_cls_str(self) -> str: ...

# External packages -- no diamond inheritance!
class NPUDeviceMixin(DeviceMixin):
    name = "npu"
    def get_device(self, local_rank): return torch.device("npu", local_rank)
    def get_distributed_backend(self): return "hccl"

class NPUSRTPlatform(SRTPlatform, NPUDeviceMixin):  # clean MRO
    ...
class NPUMMPlatform(MMPlatform, NPUDeviceMixin):    # clean MRO
    ...

@Baidu-AIAK (Contributor Author):

Great suggestion, thank you for the detailed design! Implemented this pattern. Created DeviceMixin in platforms/device_mixin.py with identity queries + device operations (all raising NotImplementedError). SRTPlatform extends DeviceMixin for SRT-specific factory methods. OOT plugins compose via MySRTPlatform(SRTPlatform, MyDeviceMixin) — clean MRO, no diamond inheritance. The MMPlatform(DeviceMixin) slot is ready for when the multimodal subsystem migrates to this pattern.

Comment on lines +50 to +57
from sglang.srt.platforms import current_platform

if current_platform.is_out_of_tree():
    backend_cls = current_platform.get_piecewise_backend_cls()
elif is_npu():
    backend_cls = NPUPiecewiseBackend
else:
    backend_cls = CUDAPiecewiseBackend
@alexnails (Collaborator) commented Mar 27, 2026:

really what this should be is just

from sglang.srt.platforms import current_platform
current_platform.get_piecewise_backend_cls()

we do not care that the platform is OOT; the hardware plugin itself should be able to implement a PiecewiseBackend class that we can run (and we should know from flags / server args / etc., wherever we already determine whether the piecewise backend can be used)

Ideally, we have a unified platform dispatch

@Baidu-AIAK (Contributor Author):

Completely agree with the direction! This is tracked as part of the future migration described in plugin.md under "Current Scope & Future Direction". Currently, in-tree platforms (CUDA/NPU) still use direct imports rather than the platform interface. Once each in-tree backend is migrated to its own SRTPlatform subclass, the if/elif chain here will collapse into a single current_platform.get_piecewise_backend_cls() call. For this PR, we took the minimal non-intrusive approach of adding the OOT branch alongside existing logic.

Comment on lines +218 to +226
@classmethod
def support_cuda_graph(cls) -> bool:
    """Whether this platform supports CUDA graph capture."""
    return True

@classmethod
def support_cublas(cls) -> bool:
    """Whether this platform supports cuBLAS initialization."""
    return False
@alexnails (Collaborator) commented Mar 27, 2026:

why are these in an interface and not part of the CUDA implementation? This comment can be seen as a paintbrush for quite a few things in here. Let's chat more about this.

@Baidu-AIAK (Contributor Author):

Cleaned this up — CUDA-specific methods have been removed from the base class. Methods now raise NotImplementedError (fail-fast). support_cublas() has been deleted entirely. support_cuda_graph() is kept — many non-CUDA platforms (ROCm, MUSA, and potentially OOT devices) support a similar graph capture mechanism. In the codebase it gates init_device_graphs() and disable_piecewise_cuda_graph, which are hardware-agnostic graph capture paths. It defaults to False (conservative), so platforms opt in explicitly. We could consider renaming it to support_device_graph() in a follow-up if the naming feels misleading.

from sglang.srt.platforms import current_platform

if current_platform.is_out_of_tree():
    mem_bytes = current_platform.get_device_total_memory()
Collaborator:

is this not a numerically safe call? (will it default to 0?)

and same comment as https://github.com/sgl-project/sglang/pull/21388/changes#r2999364494.

the bottom of these should be pinned to their platform implementations

@Baidu-AIAK (Contributor Author):

Thanks for looking closely. Without the OOT branch, the if/elif chain falls through with no return value (None), which causes a TypeError downstream when used in arithmetic. The OOT branch calls current_platform.get_device_total_memory() which OOT plugins must implement (raises NotImplementedError if not).
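
The shape of the fix, condensed into a sketch — the in-tree branch is a stand-in, and only the OOT branch mirrors this PR:

from sglang.srt.platforms import current_platform

def get_device_memory_capacity(device: str):
    if device == "cuda":
        import torch
        return torch.cuda.get_device_properties(0).total_memory
    elif current_platform.is_out_of_tree():
        # OOT plugins must implement this; the base class raises NotImplementedError
        return current_platform.get_device_total_memory()
    # before this PR the chain fell through here, returning None and raising a
    # TypeError once the caller used the value in arithmetic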

@@ -0,0 +1,197 @@
"""
Function-level hook registry for SGLang plugins.

Collaborator:

I really like this! However.... we can probably (and should) make this a decorator

@sglang_hook("sglang.srt.managers.scheduler.Scheduler.schedule", type="around")
def my_timer(original_fn, *args, **kwargs):
    ...

Makes things easier (especially if we ever expose this as some form of JIT to something being hooked)

@Baidu-AIAK (Contributor Author):

Added the sglang_hook decorator as suggested. Both the imperative HookRegistry.register() API and the decorator @sglang_hook(target, type=HookType.AROUND) are now available, so plugin authors can pick whichever style they prefer.


Allows plugins to transparently replace classes in the sglang engine
with custom implementations. Similar to vLLM's CustomOp.register_oot pattern.

Collaborator:

this entire class is also just

@hook("sglang.srt.some.Class", type="replace")

what do you think?

@Baidu-AIAK (Contributor Author):

Done! ClassReplacer has been merged into HookRegistry, and class_replacer.py is deleted. Class replacement is now a special case of HookType.REPLACE within the unified hook system. The plugin_hook decorator handles both function hooks and class replacement.

Comment thread python/sglang/srt/plugins/__init__.py Outdated

# Entry point group names
PLATFORM_PLUGINS_GROUP = "sglang.platform_plugins"
GENERAL_PLUGINS_GROUP = "sglang.general_plugins"
Collaborator:

+1 to @merrymercy's comments. I do not like this name (Slack me if you want and I can help workshop it)

@Baidu-AIAK (Contributor Author):

Renamed to sglang.plugins — shorter and consistent with the actual package structure.

Comment thread python/sglang/srt/plugins/__init__.py Outdated
    Returns:
        Dictionary mapping plugin name to its loaded callable.
    """
    from importlib.metadata import entry_points
@alexnails (Collaborator) commented Mar 27, 2026:

this logic can be simplified via some toml work

def discover_platforms() -> dict[str, type[SRTPlatform]]:
    platforms: dict[str, type[SRTPlatform]] = {}

    # 1. Built-in (always available)
    from sglang.srt.platforms.cuda import CUDASRTPlatform
    from sglang.srt.platforms.cpu import CPUSRTPlatform
    platforms["cuda"] = CUDASRTPlatform
    platforms["cpu"] = CPUSRTPlatform
    # rest of in tree platforms

    # 2. entry_points from pip-installed packages
    for ep in importlib.metadata.entry_points(group="sglang.platforms"):
        try:
            cls = ep.load()
            platforms[ep.name] = cls
        except Exception as e:
            logger.warning(f"Failed to load platform plugin {ep.name}: {e}")

    # 3. SGLANG_PLATFORM_PLUGIN override (dev/testing)
    if plugin_spec := os.environ.get("SGLANG_PLATFORM_PLUGIN"):
        name, qualname = plugin_spec.split(":", 1) if ":" in plugin_spec else (plugin_spec, plugin_spec)
        cls = resolve_obj_by_qualname(qualname)
        platforms[name] = cls

    return platforms


def get_platform(device: str) -> SRTPlatform:
    """Return the platform for a given device string. Caches instances."""
    # Looks up from discovered platforms, instantiates, caches
    ...

packages register via pyproject.toml:

[project.entry-points."sglang.platforms"]
npu = "sglang_npu.platform:NPUSRTPlatform"

@Baidu-AIAK (Contributor Author):

Thank you for the thoughtful design! We simplified __init__.py to only do entry_points discovery for OOT platforms. Built-in platforms (CUDA/ROCm/NPU/XPU) are not registered as platform plugins yet — they continue to use the existing is_cuda() utility functions (432+ call sites across 195 files). This avoids a massive refactor in this PR while achieving the same goal of clean OOT extensibility. Once the interfaces stabilize, built-in platforms can be gradually migrated to the same plugin architecture, which aligns with your suggested end-state.

def is_out_of_tree(self) -> bool:
    """Returns True for externally-registered OOT platforms."""
    return self._enum == PlatformEnum.OOT

Collaborator:

as a general comment, shouldn't quite a few of these be instance methods and not class methods?

@Baidu-AIAK (Contributor Author):

Agreed, thanks! All methods in both DeviceMixin and SRTPlatform are now instance methods. current_platform is an instance (lazy singleton), so everything is accessed via current_platform.some_method().

Comment thread python/sglang/srt/platforms/__init__.py Outdated
"""
Discover and instantiate the active platform.

Priority: OOT plugins > builtin detection.
Collaborator:

Do you think it's a good idea to have a class member for priority? @alexnails

# OOT
class OOTPlatform(Platform):
    priority = 500

# built-in
class CUDAPlatform(Platform):
    priority = 100

class CPUPlatform(Platform):
    priority = 10

@Baidu-AIAK (Contributor Author), replying to review feedback:

> Overall looks good to me! Left some minor style comments. IMPORTANT: We need at least one unit test for each plugin system

@merrymercy Added unit tests for the plugin system: test_hook_registry.py covers hook semantics (AROUND/BEFORE/AFTER/REPLACE), classmethod/staticmethod descriptor preservation, hook ordering, cross-target conflict detection, onion model, and idempotency; test_load_plugins.py covers plugin loading idempotency, exception resilience, SGLANG_PLUGINS whitelist filtering, and SGLANG_PLATFORM dist exclusion. 21 tests total.
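
For reference, an AROUND-hook test in that style might look like the sketch below; it uses plain unittest (the thread below asks for CustomTestCase in the actual suite), and the register-then-apply flow is assumed from this PR's description rather than copied from the merged tests:

import unittest

from sglang.srt.plugins.hook_registry import HookRegistry, HookType, plugin_hook

class _Greeter:
    def greet(self) -> str:
        return "hello"

class TestAroundHook(unittest.TestCase):
    def test_around_wraps_original(self):
        events = []

        @plugin_hook(f"{_Greeter.__module__}._Greeter.greet", HookType.AROUND)
        def timing_hook(original_fn, self_, *args, **kwargs):
            events.append("before")
            result = original_fn(self_, *args, **kwargs)
            events.append("after")
            return result

        HookRegistry.apply_hooks()  # idempotent, per the tests described above
        self.assertEqual(_Greeter().greet(), "hello")
        self.assertEqual(events, ["before", "after"])

if __name__ == "__main__":
    unittest.main()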

@AgainstEntropy (Collaborator):

hi @Baidu-AIAK, can you use CustomTestCase instead of unittest.TestCase in test_platform_interface.py?
The reason can be found in test/registered/unit/README.md and the write-sglang-test skill.

def test_custom_supports_fp8(self):
    """Test platform can override supports_fp8."""

    class CustomPlatform(SRTPlatform):
Collaborator:

there are many duplications of class CustomPlatform(SRTPlatform).
Can we extract it and make it shared across tests?

@Baidu-AIAK (Contributor Author):

Done. Extracted a shared _StubPlatform subclass that provides minimal concrete implementations, replacing the repeated CustomPlatform(SRTPlatform) boilerplate across tests.

"""Test is_out_of_tree returns False for non-OOT platform."""
self.assertFalse(self.mixin.is_out_of_tree())

def test_empty_cache_noop(self):
Collaborator:

Some *_noop and *_raises_not_implemented tests here are actually testing the "Python language itself" instead of sglang logic.
The related methods of DeviceMixin are not overridden by a real device mixin class.

Maybe we should consider removing them.

@Baidu-AIAK (Contributor Author):

Done. Removed all *_raises_not_implemented and *_noop tests, as well as other tests that only verify Python behavior (default return values, no-op methods, repr formatting, trivial overrides). The remaining tests focus on SGLang-specific logic: classification rules, validation boundaries, error paths, and branching.

Collaborator:

We may need some platform discovery and _resolve_platform related tests, which are very important sglang logic.

@Baidu-AIAK (Contributor Author):

Done. Added comprehensive tests for platform discovery: _resolve_platform (both the SGLANG_PLATFORM env branch and the auto-discover branch) and _load_platform_class qualname resolution with type validation.

@merrymercy (Contributor) left a comment:

Looks good! We can merge once CI is green

@Baidu-AIAK (Contributor Author) commented Apr 16, 2026, quoting a question from @kjuuii:

hi @Baidu-AIAK , I am trying to use the new platform plugin mechanism from PR #21388 for an MLU backend, but my plugin never gets activated, so current_platform falls back to SRTPlatform(device=unknown) and is_out_of_tree() never becomes true.

I first thought this was only an OOT detection issue, but after more testing it looks like the plugin import itself is failing before activation.

What I tested:

  1. Entry point group in pyproject.toml:
[project.entry-points."sglang.srt.platforms"]
my_device = "my_platform_plugin:activate"
  2. When I run:
python -c "from sglang.srt.platforms import current_platform; print(current_platform)"

I get:

Failed to load plugin my_device from group sglang.srt.platforms
...
File "/workspace/sglang-mlu/my_platform_plugin/my_platform_plugin/__init__.py", line 4, in <module>
    from .srt.platform.device import MluDeviceMixin
ModuleNotFoundError: No module named 'my_platform_plugin.srt'
No platform detected. Using base SRTPlatform with defaults.
SRTPlatform(device=unknown)

So at this point it seems the plugin is not even reaching activate(). The failure happens during module import.

My current __init__.py imports:

from .srt.platform.device import MluDeviceMixin
from .srt.platform.platform import MluSRTPlatform

But my suspicion is that this package layout is incompatible with the plugin mechanism, because my_platform_plugin.srt does not actually exist as an importable subpackage at runtime.

I also have:

class MluDeviceMixin(DeviceMixin):
    _enum = PlatformEnum.OOT
    device_name = "mlu"
    device_type = "mlu"

So my understanding is:

  • is_out_of_tree() should become true automatically once the platform class is successfully activated
  • the current root cause is probably plugin import/package layout, not the OOT check itself

Could you please help confirm whether this understanding is correct?

Also, for an out-of-tree backend package, what is the recommended import/package structure here? Should the plugin package be fully self-contained like the minimal example in plugin.md, or is it valid to structure it like sglang_mlu.srt... and return something like:

return "sglang_mlu.srt.platform.platform.MluSRTPlatform"

from activate()?

Thanks.

@kjuuii Hi, your understanding is correct on both points:

  1. is_out_of_tree() returns True automatically once your platform class with _enum = PlatformEnum.OOT is successfully activated — no extra step needed.
  2. The root cause is the top-level import in __init__.py, not the OOT check itself. When the entry point my_device = "my_platform_plugin:activate" is loaded, Python first imports my_platform_plugin (executing __init__.py). Since my_platform_plugin.srt doesn't exist as an importable subpackage, the ModuleNotFoundError is raised before activate() is ever reached.

Keep __init__.py minimal — only define activate(), with no top-level subpackage imports:

# __init__.py
def activate():
    """Return FQN of platform class, or None to skip."""
    if _mlu_is_available():
        return "my_platform_plugin.platform.MluSRTPlatform"
    return None

And use a flat, self-contained package structure:

my_platform_plugin/
├── pyproject.toml
└── my_platform_plugin/
    ├── __init__.py    # activate() only, no top-level imports
    ├── device.py      # MluDeviceMixin
    └── platform.py    # MluSRTPlatform

@Baidu-AIAK (Contributor Author):

/rerun-failed-ci

@Baidu-AIAK (Contributor Author):

@merrymercy
The multimodal-gen-test-1-b200 stage is failing with errors not caused by this PR:
ValueError: New parameter 'blocks.11.attn2.to_k.weight_scale_2' is not supported
ModuleNotFoundError: No module named 'modelopt'

raise ValueError(
  ValueError: New parameter 'blocks.11.attn2.to_k.weight_scale_2' is not supported. Checkpoint-specific synthesized parameters should either match ['gate_compress', 'wcscales', 'wtscale', 'input_scale', 'bias', 'norm_q', 'norm_k'] or declare missing_param_init.

Could we skip this stage?

@Baidu-AIAK (Contributor Author):

/rerun-failed-ci

@merrymercy merrymercy merged commit 7ca3566 into sgl-project:main Apr 20, 2026
811 of 915 checks passed
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
Co-authored-by: root <root@tjzj-inf-sci-k8s-bzz2-0183.tjzj.baidu.com>
Co-authored-by: Alex Nails <alex.nails@radixark.ai>
Co-authored-by: Alex Nails <alexj.nails@gmail.com>
Co-authored-by: root <root@tjzj-inf-sci-k8s-bzz2-0000.tjzj.baidu.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
Co-authored-by: root <root@tjzj-inf-sci-k8s-bzz2-0183.tjzj.baidu.com>
Co-authored-by: Alex Nails <alex.nails@radixark.ai>
Co-authored-by: Alex Nails <alexj.nails@gmail.com>
Co-authored-by: root <root@tjzj-inf-sci-k8s-bzz2-0000.tjzj.baidu.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
@hnyls2002 hnyls2002 mentioned this pull request Apr 29, 2026
@afei6 commented May 7, 2026:

Hi, do we have a plan to implement the same multi-platform plugin system for multimodal-gen too?

@alexnails (Collaborator):

@afei6 the plugin itself will have an MMPlatform to compose with; it's just not in the current set of tasks we are working on, as the current priority is the SRTPlatform side. I will write a roadmap doc for tasks people can take.

Labels

documentation, high priority, piecewise-cuda-graph, run-ci

8 participants